Skip to content

Commit 70f65eb

Browse files
committed
Add support for an "accept handler" in connection forwarding
This commit adds support for a new accept_handler argument in the forward_local_port and forward_local_port_to_path methods in SSHClientConnection and the ability to return an accept handler in the server_requested method in SSHServer. This method receives the original host & port of the incoming forwarded connection and can return a bool to determine whether forwarding is allowed or not. Thanks go to GitHub user zgxkbtl for suggesting this feature!
1 parent 777d328 commit 70f65eb

File tree

4 files changed

+164
-10
lines changed

4 files changed

+164
-10
lines changed

asyncssh/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444

4545
from .connection import SSHAcceptor, SSHClientConnection, SSHServerConnection
4646
from .connection import SSHClientConnectionOptions, SSHServerConnectionOptions
47+
from .connection import SSHAcceptHandler
4748
from .connection import create_connection, create_server, connect, listen
4849
from .connection import connect_reverse, listen_reverse, get_server_host_key
4950
from .connection import get_server_auth_methods, run_client, run_server

asyncssh/connection.py

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ async def create_server(self, session_factory: TCPListenerFactory,
233233

234234
_VersionArg = DefTuple[BytesOrStr]
235235

236+
SSHAcceptHandler = Callable[[str, int], MaybeAwait[bool]]
236237

237238
# SSH service names
238239
_USERAUTH_SERVICE = b'ssh-userauth'
@@ -2886,10 +2887,10 @@ async def forward_unix_connection(self, dest_path: str) -> SSHForwarder:
28862887
return SSHForwarder(cast(SSHForwarder, peer))
28872888

28882889
@async_context_manager
2889-
async def forward_local_port(self, listen_host: str,
2890-
listen_port: int,
2891-
dest_host: str,
2892-
dest_port: int) -> SSHListener:
2890+
async def forward_local_port(
2891+
self, listen_host: str, listen_port: int,
2892+
dest_host: str, dest_port: int,
2893+
accept_handler: Optional[SSHAcceptHandler] = None) -> SSHListener:
28932894
"""Set up local port forwarding
28942895
28952896
This method is a coroutine which attempts to set up port
@@ -2906,10 +2907,17 @@ async def forward_local_port(self, listen_host: str,
29062907
The hostname or address to forward the connections to
29072908
:param dest_port:
29082909
The port number to forward the connections to
2910+
:param accept_handler:
2911+
A `callable` or coroutine which takes arguments of the
2912+
original host and port of the client and decides whether
2913+
or not to allow connection forwarding, returning `True` to
2914+
accept the connection and begin forwarding or `False` to
2915+
reject and close it.
29092916
:type listen_host: `str`
29102917
:type listen_port: `int`
29112918
:type dest_host: `str`
29122919
:type dest_port: `int`
2920+
:type accept_handler: `callable` or coroutine
29132921
29142922
:returns: :class:`SSHListener`
29152923
@@ -2923,6 +2931,21 @@ async def tunnel_connection(
29232931
Tuple[SSHTCPChannel[bytes], SSHTCPSession[bytes]]:
29242932
"""Forward a local connection over SSH"""
29252933

2934+
if accept_handler:
2935+
result = accept_handler(orig_host, orig_port)
2936+
2937+
if inspect.isawaitable(result):
2938+
result = await cast(Awaitable[bool], result)
2939+
2940+
if not result:
2941+
self.logger.info('Request for TCP forwarding from '
2942+
'%s to %s denied by application',
2943+
(orig_host, orig_port),
2944+
(dest_host, dest_port))
2945+
2946+
raise ChannelOpenError(OPEN_ADMINISTRATIVELY_PROHIBITED,
2947+
'Connection forwarding denied')
2948+
29262949
return (await self.create_connection(session_factory,
29272950
dest_host, dest_port,
29282951
orig_host, orig_port))
@@ -4695,9 +4718,9 @@ async def listen_reverse_ssh(self, host: str = '',
46954718
**kwargs) # type: ignore
46964719

46974720
@async_context_manager
4698-
async def forward_local_port_to_path(self, listen_host: str,
4699-
listen_port: int,
4700-
dest_path: str) -> SSHListener:
4721+
async def forward_local_port_to_path(
4722+
self, listen_host: str, listen_port: int, dest_path: str,
4723+
accept_handler: Optional[SSHAcceptHandler] = None) -> SSHListener:
47014724
"""Set up local TCP port forwarding to a remote UNIX domain socket
47024725
47034726
This method is a coroutine which attempts to set up port
@@ -4712,9 +4735,16 @@ async def forward_local_port_to_path(self, listen_host: str,
47124735
The port number on the local host to listen on
47134736
:param dest_path:
47144737
The path on the remote host to forward the connections to
4738+
:param accept_handler:
4739+
A `callable` or coroutine which takes arguments of the
4740+
original host and port of the client and decides whether
4741+
or not to allow connection forwarding, returning `True` to
4742+
accept the connection and begin forwarding or `False` to
4743+
reject and close it.
47154744
:type listen_host: `str`
47164745
:type listen_port: `int`
47174746
:type dest_path: `str`
4747+
:type accept_handler: `callable` or coroutine
47184748
47194749
:returns: :class:`SSHListener`
47204750
@@ -4724,10 +4754,24 @@ async def forward_local_port_to_path(self, listen_host: str,
47244754

47254755
async def tunnel_connection(
47264756
session_factory: SSHUNIXSessionFactory[bytes],
4727-
_orig_host: str, _orig_port: int) -> \
4757+
orig_host: str, orig_port: int) -> \
47284758
Tuple[SSHUNIXChannel[bytes], SSHUNIXSession[bytes]]:
47294759
"""Forward a local connection over SSH"""
47304760

4761+
if accept_handler:
4762+
result = accept_handler(orig_host, orig_port)
4763+
4764+
if inspect.isawaitable(result):
4765+
result = await cast(Awaitable[bool], result)
4766+
4767+
if not result:
4768+
self.logger.info('Request for TCP forwarding from '
4769+
'%s to %s denied by application',
4770+
(orig_host, orig_port), dest_path)
4771+
4772+
raise ChannelOpenError(OPEN_ADMINISTRATIVELY_PROHIBITED,
4773+
'Connection forwarding denied')
4774+
47314775
return (await self.create_unix_connection(session_factory,
47324776
dest_path))
47334777

@@ -5737,6 +5781,10 @@ async def _finish_port_forward(self, listen_host: str,
57375781
if listener is True:
57385782
listener = await self.forward_local_port(
57395783
listen_host, listen_port, listen_host, listen_port)
5784+
elif callable(listener):
5785+
listener = await self.forward_local_port(
5786+
listen_host, listen_port,
5787+
listen_host, listen_port, listener)
57405788
except OSError:
57415789
self.logger.debug1('Failed to create TCP listener')
57425790
self._report_global_response(False)

asyncssh/server.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
if TYPE_CHECKING:
3333
# pylint: disable=cyclic-import
34-
from .connection import SSHServerConnection
34+
from .connection import SSHServerConnection, SSHAcceptHandler
3535
from .channel import SSHServerChannel, SSHTCPChannel, SSHUNIXChannel
3636
from .session import SSHServerSession, SSHTCPSession, SSHUNIXSession
3737

@@ -45,7 +45,7 @@
4545
_NewUNIXSession = Union[bool, 'SSHUNIXSession', SSHSocketSessionFactory,
4646
Tuple['SSHUNIXChannel', 'SSHUNIXSession'],
4747
Tuple['SSHUNIXChannel', SSHSocketSessionFactory]]
48-
_NewListener = Union[bool, SSHListener]
48+
_NewListener = Union[bool, 'SSHAcceptHandler', SSHListener]
4949

5050

5151
class SSHServer:

tests/test_forward.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,18 @@ async def server_requested(self, listen_host, listen_port):
183183
return listen_host != 'fail'
184184

185185

186+
class _TCPAcceptHandlerServer(Server):
187+
"""Server for testing forwarding accept handler"""
188+
189+
async def server_requested(self, listen_host, listen_port):
190+
"""Handle a request to create a new socket listener"""
191+
192+
def accept_handler(_orig_host: str, _orig_port: int) -> bool:
193+
return True
194+
195+
return accept_handler
196+
197+
186198
class _UNIXConnectionServer(Server):
187199
"""Server for testing direct and forwarded UNIX domain connections"""
188200

@@ -594,6 +606,39 @@ async def test_forward_local_port(self):
594606
await self._check_local_connection(listener.get_port(),
595607
delay=0.1)
596608

609+
@asynctest
610+
async def test_forward_local_port_accept_handler(self):
611+
"""Test forwarding of a local port with an accept handler"""
612+
613+
def accept_handler(_orig_host: str, _orig_port: int) -> bool:
614+
return True
615+
616+
async with self.connect() as conn:
617+
async with conn.forward_local_port('', 0, '', 7,
618+
accept_handler) as listener:
619+
await self._check_local_connection(listener.get_port(),
620+
delay=0.1)
621+
622+
@asynctest
623+
async def test_forward_local_port_accept_handler_denial(self):
624+
"""Test forwarding of a local port with an accept handler denial"""
625+
626+
async def accept_handler(_orig_host: str, _orig_port: int) -> bool:
627+
return False
628+
629+
async with self.connect() as conn:
630+
async with conn.forward_local_port('', 0, '', 7,
631+
accept_handler) as listener:
632+
listen_port = listener.get_port()
633+
634+
reader, writer = await asyncio.open_connection('127.0.0.1',
635+
listen_port)
636+
637+
self.assertEqual((await reader.read()), b'')
638+
639+
writer.close()
640+
await maybe_wait_closed(writer)
641+
597642
@unittest.skipIf(sys.platform == 'win32',
598643
'skip UNIX domain socket tests on Windows')
599644
@asynctest
@@ -855,6 +900,33 @@ async def test_listener_close_on_conn_close(self):
855900
await listener.wait_closed()
856901

857902

903+
class _TestTCPForwardingAcceptHandler(_CheckForwarding):
904+
"""Unit tests for TCP forwarding with accept handler"""
905+
906+
@classmethod
907+
async def start_server(cls):
908+
"""Start an SSH server which supports TCP connection forwarding"""
909+
910+
return await cls.create_server(
911+
_TCPAcceptHandlerServer, authorized_client_keys='authorized_keys')
912+
913+
@asynctest
914+
async def test_forward_remote_port_accept_handler(self):
915+
"""Test forwarding of a remote port with accept handler"""
916+
917+
server = await asyncio.start_server(echo, None, 0,
918+
family=socket.AF_INET)
919+
server_port = server.sockets[0].getsockname()[1]
920+
921+
async with self.connect() as conn:
922+
async with conn.forward_remote_port(
923+
'', 0, '127.0.0.1', server_port) as listener:
924+
await self._check_local_connection(listener.get_port())
925+
926+
server.close()
927+
await server.wait_closed()
928+
929+
858930
class _TestAsyncTCPForwarding(_TestTCPForwarding):
859931
"""Unit tests for AsyncSSH TCP connection forwarding with async return"""
860932

@@ -999,6 +1071,39 @@ async def test_forward_local_path(self):
9991071

10001072
os.remove('local')
10011073

1074+
@asynctest
1075+
async def test_forward_local_port_to_path_accept_handler(self):
1076+
"""Test forwarding of port to UNIX path with accept handler"""
1077+
1078+
def accept_handler(_orig_host: str, _orig_port: int) -> bool:
1079+
return True
1080+
1081+
async with self.connect() as conn:
1082+
async with conn.forward_local_port_to_path(
1083+
'', 0, '/echo', accept_handler) as listener:
1084+
await self._check_local_connection(listener.get_port(),
1085+
delay=0.1)
1086+
1087+
@asynctest
1088+
async def test_forward_local_port_to_path_accept_handler_denial(self):
1089+
"""Test forwarding of port to UNIX path with accept handler denial"""
1090+
1091+
async def accept_handler(_orig_host: str, _orig_port: int) -> bool:
1092+
return False
1093+
1094+
async with self.connect() as conn:
1095+
async with conn.forward_local_port_to_path(
1096+
'', 0, '/echo', accept_handler) as listener:
1097+
listen_port = listener.get_port()
1098+
1099+
reader, writer = await asyncio.open_connection('127.0.0.1',
1100+
listen_port)
1101+
1102+
self.assertEqual((await reader.read()), b'')
1103+
1104+
writer.close()
1105+
await maybe_wait_closed(writer)
1106+
10021107
@asynctest
10031108
async def test_forward_local_port_to_path(self):
10041109
"""Test forwarding of a local port to a remote UNIX domain socket"""

0 commit comments

Comments
 (0)