Skip to content

Commit 1328e69

Browse files
sage-maker and benieric
authored and
root
committed
Fix ssh host policy (aws#4966)
* Fix ssh host policy
* Filter policy by algo-
* Add docstring
* Fix pylint
* Fix docstyle summary
* Unit test
* Fix unit test
* Change to unit test
* Fix unit tests
* Test comment out flaky tests
* Readd the flaky tests
* Remove flaky asserts
* Remove flaky asserts

---------

Co-authored-by: Erick Benitez-Ramos <[email protected]>
1 parent 71f6d22 commit 1328e69

File tree

2 files changed

+153
-14
lines changed

2 files changed

+153
-14
lines changed

src/sagemaker/modules/train/container_drivers/mpi_utils.py

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
from __future__ import absolute_import
1515

1616
import os
17-
import time
1817
import subprocess
19-
18+
import time
2019
from typing import List
2120

22-
from utils import logger, SM_EFA_NCCL_INSTANCES, SM_EFA_RDMA_INSTANCES, get_python_executable
21+
import paramiko
22+
from utils import SM_EFA_NCCL_INSTANCES, SM_EFA_RDMA_INSTANCES, get_python_executable, logger
2323

2424
FINISHED_STATUS_FILE = "/tmp/done.algo-1"
2525
READY_FILE = "/tmp/ready.%s"
@@ -75,19 +75,45 @@ def start_sshd_daemon():
7575
logger.info("Started SSH daemon.")
7676

7777

78+
class CustomHostKeyPolicy(paramiko.client.MissingHostKeyPolicy):
    """Host key policy for SageMaker distributed training SSH connections.

    Unknown host keys are accepted only for hosts whose name starts with
    ``algo-``; every other unknown host is rejected with an ``SSHException``.

    Example:
        >>> client = paramiko.SSHClient()
        >>> client.set_missing_host_key_policy(CustomHostKeyPolicy())
        >>> # Will succeed for SageMaker algorithm containers
        >>> client.connect('algo-1234.internal')
        >>> # Will raise SSHException for other unknown hosts
        >>> client.connect('unknown-host')  # raises SSHException
    """

    def missing_host_key(self, client, hostname, key):
        """Accept host keys for algo-* hostnames, reject others.

        Args:
            client: The SSHClient instance
            hostname: The host key name being checked. paramiko passes the
                plain hostname for port 22 and ``[hostname]:port`` for any
                other port; both forms are recognised here.
            key: The host key

        Raises:
            paramiko.SSHException: If hostname doesn't match algo-* pattern
        """
        # For non-default ports paramiko wraps the name as "[host]:port";
        # drop the opening bracket so the prefix check sees the real hostname.
        bare_name = hostname[1:] if hostname.startswith("[") else hostname
        if not bare_name.startswith("algo-"):
            raise paramiko.SSHException(f"Unknown host key for {hostname}")

        # Store the key under the exact name paramiko supplied so later
        # lookups for the same host/port pair find it.
        client.get_host_keys().add(hostname, key.get_name(), key)
105+
106+
78107
def _can_connect(host: str, port: int = DEFAULT_SSH_PORT) -> bool:
79108
"""Check if the connection to the provided host and port is possible."""
80109
try:
81-
import paramiko
82-
83110
logger.debug("Testing connection to host %s", host)
84-
client = paramiko.SSHClient()
85-
client.load_system_host_keys()
86-
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
87-
client.connect(host, port=port)
88-
client.close()
89-
logger.info("Can connect to host %s", host)
90-
return True
111+
with paramiko.SSHClient() as client:
112+
client.load_system_host_keys()
113+
client.set_missing_host_key_policy(CustomHostKeyPolicy())
114+
client.connect(host, port=port)
115+
logger.info("Can connect to host %s", host)
116+
return True
91117
except Exception as e: # pylint: disable=W0703
92118
logger.info("Cannot connect to host %s", host)
93119
logger.debug(f"Connection failed with exception: {e}")
@@ -183,9 +209,9 @@ def validate_smddpmprun() -> bool:
183209

184210
def write_env_vars_to_file():
    """Append every current environment variable to /etc/environment.

    Each variable is written on its own line in ``NAME=value`` form.
    """
    with open("/etc/environment", "a", encoding="utf-8") as env_file:
        env_file.writelines(f"{key}={value}\n" for key, value in os.environ.items())
189215

190216

191217
def get_mpirun_command(
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""MPI Utils Unit Tests."""
14+
from __future__ import absolute_import
15+
16+
import subprocess
17+
from unittest.mock import Mock, patch
18+
19+
import paramiko
20+
import pytest
21+
22+
# Stand-in for the container-local ``utils`` module. It has to be registered in
# ``sys.modules`` before mpi_utils is imported, because mpi_utils imports
# names from ``utils`` at module load time.
mock_utils = Mock(
    logger=Mock(),
    SM_EFA_NCCL_INSTANCES=[],
    SM_EFA_RDMA_INSTANCES=[],
    get_python_executable=Mock(return_value="/usr/bin/python"),
)

with patch.dict("sys.modules", {"utils": mock_utils}):
    from sagemaker.modules.train.container_drivers.mpi_utils import (
        CustomHostKeyPolicy,
        _can_connect,
        write_status_file_to_workers,
    )
35+
36+
TEST_HOST = "algo-1"
37+
TEST_WORKER = "algo-2"
38+
TEST_STATUS_FILE = "/tmp/test-status"
39+
40+
41+
def test_custom_host_key_policy_valid_hostname():
    """Test CustomHostKeyPolicy stores the key for an algo- prefixed hostname."""
    ssh_client = Mock()
    host_key = Mock()
    host_key.get_name.return_value = "ssh-rsa"
    # Handle to the object returned by client.get_host_keys(), taken without
    # registering an extra call on get_host_keys itself.
    host_keys = ssh_client.get_host_keys.return_value

    CustomHostKeyPolicy().missing_host_key(ssh_client, "algo-1", host_key)

    ssh_client.get_host_keys.assert_called_once()
    host_keys.add.assert_called_once_with("algo-1", "ssh-rsa", host_key)
52+
53+
54+
def test_custom_host_key_policy_invalid_hostname():
    """Test CustomHostKeyPolicy rejects non-algo prefixed hostnames."""
    ssh_client = Mock()
    host_key = Mock()
    rejected_host = "invalid-1"

    with pytest.raises(paramiko.SSHException) as raised:
        CustomHostKeyPolicy().missing_host_key(ssh_client, rejected_host, host_key)

    message = str(raised.value)
    assert "Unknown host key for invalid-1" in message
    # A rejected host must never reach the host-key store.
    ssh_client.get_host_keys.assert_not_called()
65+
66+
67+
@patch("paramiko.SSHClient")
@patch("sagemaker.modules.train.container_drivers.mpi_utils.logger")
def test_can_connect_success(mock_logger, mock_ssh_client):
    """Test successful SSH connection.

    Verifies that ``_can_connect`` returns True, loads the system host keys,
    installs a ``CustomHostKeyPolicy`` and connects on the default port.
    """
    mock_client = Mock()
    # _can_connect uses the client as a context manager, so the object it
    # works with is whatever __enter__ returns.
    mock_ssh_client.return_value.__enter__.return_value = mock_client
    mock_client.connect.return_value = None  # Successful connection

    result = _can_connect(TEST_HOST)

    assert result is True
    mock_client.load_system_host_keys.assert_called_once()
    mock_client.set_missing_host_key_policy.assert_called_once()
    # Checking that *a* policy was set is not enough: the restricted policy,
    # not a permissive one, must be the one installed.
    installed_policy = mock_client.set_missing_host_key_policy.call_args[0][0]
    assert isinstance(installed_policy, CustomHostKeyPolicy)
    mock_client.connect.assert_called_once_with(TEST_HOST, port=22)
81+
82+
83+
@patch("paramiko.SSHClient")
@patch("sagemaker.modules.train.container_drivers.mpi_utils.logger")
def test_can_connect_failure(mock_logger, mock_ssh_client):
    """Test SSH connection failure.

    ``connect`` raises ``SSHException``; ``_can_connect`` is expected to
    report False after having configured the client with a
    ``CustomHostKeyPolicy``.
    """
    mock_client = Mock()
    # _can_connect uses the client as a context manager, so the object it
    # works with is whatever __enter__ returns.
    mock_ssh_client.return_value.__enter__.return_value = mock_client
    mock_client.connect.side_effect = paramiko.SSHException("Connection failed")

    result = _can_connect(TEST_HOST)

    assert result is False
    mock_client.load_system_host_keys.assert_called_once()
    mock_client.set_missing_host_key_policy.assert_called_once()
    # The restricted policy must be in place before the connection attempt.
    installed_policy = mock_client.set_missing_host_key_policy.call_args[0][0]
    assert isinstance(installed_policy, CustomHostKeyPolicy)
    mock_client.connect.assert_called_once_with(TEST_HOST, port=22)
97+
98+
99+
@patch("subprocess.run")
@patch("sagemaker.modules.train.container_drivers.mpi_utils.logger")
def test_write_status_file_to_workers_failure(mock_logger, mock_run):
    """Test that writing the status file times out when every attempt fails.

    Every ``subprocess.run`` call raises ``CalledProcessError``; the function
    is expected to try more than once and then raise ``TimeoutError``.
    """
    mock_run.side_effect = subprocess.CalledProcessError(1, "ssh")
    workers = [TEST_WORKER]

    with pytest.raises(TimeoutError) as raised:
        write_status_file_to_workers(workers, TEST_STATUS_FILE)

    expected_fragment = f"Timed out waiting for {TEST_WORKER}"
    assert expected_fragment in str(raised.value)
    assert mock_run.call_count > 1  # Verifies that retries occurred
110+
111+
112+
if __name__ == "__main__":
    # pytest.main returns the run's exit code; propagate it so that running
    # this file directly exits non-zero when a test fails.
    raise SystemExit(pytest.main([__file__]))

0 commit comments

Comments
 (0)