Skip to content

Commit 0d33f36

Browse files
committed
Unit test
1 parent c762139 commit 0d33f36

File tree

1 file changed

+128
-0
lines changed

1 file changed

+128
-0
lines changed
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module contains tests for MPI utility functions."""
14+
from __future__ import absolute_import
15+
16+
import os
17+
from unittest.mock import Mock, patch
18+
19+
import paramiko
20+
import pytest
21+
22+
from sagemaker.modules.train.container_drivers.mpi_utils import (
23+
CustomHostKeyPolicy,
24+
_can_connect,
25+
bootstrap_master_node,
26+
bootstrap_worker_node,
27+
get_mpirun_command,
28+
)
29+
30+
31+
def test_custom_host_key_policy_algo_host():
32+
"""Test CustomHostKeyPolicy accepts algo- hosts."""
33+
policy = CustomHostKeyPolicy()
34+
mock_client = Mock()
35+
mock_key = Mock()
36+
mock_key.get_name.return_value = "ssh-rsa"
37+
38+
# Should not raise exception for algo- hostname
39+
policy.missing_host_key(mock_client, "algo-1234", mock_key)
40+
41+
mock_client.get_host_keys.assert_called_once()
42+
mock_client.get_host_keys().add.assert_called_once_with("algo-1234", "ssh-rsa", mock_key)
43+
44+
45+
def test_custom_host_key_policy_invalid_host():
46+
"""Test CustomHostKeyPolicy rejects non-algo hosts."""
47+
policy = CustomHostKeyPolicy()
48+
mock_client = Mock()
49+
mock_key = Mock()
50+
51+
with pytest.raises(paramiko.SSHException) as exc_info:
52+
policy.missing_host_key(mock_client, "invalid-host", mock_key)
53+
54+
assert "Unknown host key for invalid-host" in str(exc_info.value)
55+
mock_client.get_host_keys.assert_not_called()
56+
57+
58+
@patch("paramiko.SSHClient")
59+
def test_can_connect_success(mock_ssh_client):
60+
"""Test successful SSH connection."""
61+
mock_client = Mock()
62+
mock_ssh_client.return_value = mock_client
63+
64+
assert _can_connect("algo-1234") is True
65+
mock_client.connect.assert_called_once()
66+
67+
68+
@patch("paramiko.SSHClient")
69+
def test_can_connect_failure(mock_ssh_client):
70+
"""Test SSH connection failure."""
71+
mock_client = Mock()
72+
mock_ssh_client.return_value = mock_client
73+
mock_client.connect.side_effect = Exception("Connection failed")
74+
75+
assert _can_connect("algo-1234") is False
76+
77+
78+
def test_get_mpirun_command():
79+
"""Test MPI command generation."""
80+
os.environ["SM_NETWORK_INTERFACE_NAME"] = "eth0"
81+
os.environ["SM_CURRENT_INSTANCE_TYPE"] = "ml.p4d.24xlarge"
82+
83+
command = get_mpirun_command(
84+
host_count=2,
85+
host_list=["algo-1", "algo-2"],
86+
num_processes=2,
87+
additional_options=[],
88+
entry_script_path="train.py",
89+
)
90+
91+
assert command[0] == "mpirun"
92+
assert "--host" in command
93+
assert "algo-1,algo-2" in command
94+
assert "-np" in command
95+
assert "2" in command
96+
assert f"NCCL_SOCKET_IFNAME=eth0" in " ".join(command)
97+
98+
99+
@patch("sagemaker.modules.train.container_drivers.mpi_utils._can_connect")
100+
@patch("sagemaker.modules.train.container_drivers.mpi_utils._write_file_to_host")
101+
def test_bootstrap_worker_node(mock_write, mock_connect):
102+
"""Test worker node bootstrapping."""
103+
mock_connect.return_value = True
104+
mock_write.return_value = True
105+
os.environ["SM_CURRENT_HOST"] = "algo-2"
106+
107+
with pytest.raises(TimeoutError):
108+
# Should timeout waiting for status file
109+
bootstrap_worker_node("algo-1", timeout=1)
110+
111+
mock_connect.assert_called_with("algo-1")
112+
mock_write.assert_called_with("algo-1", "/tmp/ready.algo-2")
113+
114+
115+
@patch("sagemaker.modules.train.container_drivers.mpi_utils._can_connect")
116+
def test_bootstrap_master_node(mock_connect):
117+
"""Test master node bootstrapping."""
118+
mock_connect.return_value = True
119+
120+
with pytest.raises(TimeoutError):
121+
# Should timeout waiting for ready files
122+
bootstrap_master_node(["algo-2"], timeout=1)
123+
124+
mock_connect.assert_called_with("algo-2")
125+
126+
127+
if __name__ == "__main__":
128+
pytest.main([__file__])

0 commit comments

Comments
 (0)