@@ -233,6 +233,7 @@ async def create_server(self, session_factory: TCPListenerFactory,
233
233
234
234
_VersionArg = DefTuple [BytesOrStr ]
235
235
236
+ SSHAcceptHandler = Callable [[str , int ], MaybeAwait [bool ]]
236
237
237
238
# SSH service names
238
239
_USERAUTH_SERVICE = b'ssh-userauth'
@@ -2886,10 +2887,10 @@ async def forward_unix_connection(self, dest_path: str) -> SSHForwarder:
2886
2887
return SSHForwarder (cast (SSHForwarder , peer ))
2887
2888
2888
2889
@async_context_manager
2889
- async def forward_local_port (self , listen_host : str ,
2890
- listen_port : int ,
2891
- dest_host : str ,
2892
- dest_port : int ) -> SSHListener :
2890
+ async def forward_local_port (
2891
+ self , listen_host : str , listen_port : int ,
2892
+ dest_host : str , dest_port : int ,
2893
+ accept_handler : Optional [ SSHAcceptHandler ] = None ) -> SSHListener :
2893
2894
"""Set up local port forwarding
2894
2895
2895
2896
This method is a coroutine which attempts to set up port
@@ -2906,10 +2907,17 @@ async def forward_local_port(self, listen_host: str,
2906
2907
The hostname or address to forward the connections to
2907
2908
:param dest_port:
2908
2909
The port number to forward the connections to
2910
+ :param accept_handler:
2911
+ A `callable` or coroutine which takes arguments of the
2912
+ original host and port of the client and decides whether
2913
+ or not to allow connection forwarding, returning `True` to
2914
+ accept the connection and begin forwarding or `False` to
2915
+ reject and close it.
2909
2916
:type listen_host: `str`
2910
2917
:type listen_port: `int`
2911
2918
:type dest_host: `str`
2912
2919
:type dest_port: `int`
2920
+ :type accept_handler: `callable` or coroutine
2913
2921
2914
2922
:returns: :class:`SSHListener`
2915
2923
@@ -2923,6 +2931,21 @@ async def tunnel_connection(
2923
2931
Tuple [SSHTCPChannel [bytes ], SSHTCPSession [bytes ]]:
2924
2932
"""Forward a local connection over SSH"""
2925
2933
2934
+ if accept_handler :
2935
+ result = accept_handler (orig_host , orig_port )
2936
+
2937
+ if inspect .isawaitable (result ):
2938
+ result = await cast (Awaitable [bool ], result )
2939
+
2940
+ if not result :
2941
+ self .logger .info ('Request for TCP forwarding from '
2942
+ '%s to %s denied by application' ,
2943
+ (orig_host , orig_port ),
2944
+ (dest_host , dest_port ))
2945
+
2946
+ raise ChannelOpenError (OPEN_ADMINISTRATIVELY_PROHIBITED ,
2947
+ 'Connection forwarding denied' )
2948
+
2926
2949
return (await self .create_connection (session_factory ,
2927
2950
dest_host , dest_port ,
2928
2951
orig_host , orig_port ))
@@ -4695,9 +4718,9 @@ async def listen_reverse_ssh(self, host: str = '',
4695
4718
** kwargs ) # type: ignore
4696
4719
4697
4720
@async_context_manager
4698
- async def forward_local_port_to_path (self , listen_host : str ,
4699
- listen_port : int ,
4700
- dest_path : str ) -> SSHListener :
4721
+ async def forward_local_port_to_path (
4722
+ self , listen_host : str , listen_port : int , dest_path : str ,
4723
+ accept_handler : Optional [ SSHAcceptHandler ] = None ) -> SSHListener :
4701
4724
"""Set up local TCP port forwarding to a remote UNIX domain socket
4702
4725
4703
4726
This method is a coroutine which attempts to set up port
@@ -4712,9 +4735,16 @@ async def forward_local_port_to_path(self, listen_host: str,
4712
4735
The port number on the local host to listen on
4713
4736
:param dest_path:
4714
4737
The path on the remote host to forward the connections to
4738
+ :param accept_handler:
4739
+ A `callable` or coroutine which takes arguments of the
4740
+ original host and port of the client and decides whether
4741
+ or not to allow connection forwarding, returning `True` to
4742
+ accept the connection and begin forwarding or `False` to
4743
+ reject and close it.
4715
4744
:type listen_host: `str`
4716
4745
:type listen_port: `int`
4717
4746
:type dest_path: `str`
4747
+ :type accept_handler: `callable` or coroutine
4718
4748
4719
4749
:returns: :class:`SSHListener`
4720
4750
@@ -4724,10 +4754,24 @@ async def forward_local_port_to_path(self, listen_host: str,
4724
4754
4725
4755
async def tunnel_connection (
4726
4756
session_factory : SSHUNIXSessionFactory [bytes ],
4727
- _orig_host : str , _orig_port : int ) -> \
4757
+ orig_host : str , orig_port : int ) -> \
4728
4758
Tuple [SSHUNIXChannel [bytes ], SSHUNIXSession [bytes ]]:
4729
4759
"""Forward a local connection over SSH"""
4730
4760
4761
+ if accept_handler :
4762
+ result = accept_handler (orig_host , orig_port )
4763
+
4764
+ if inspect .isawaitable (result ):
4765
+ result = await cast (Awaitable [bool ], result )
4766
+
4767
+ if not result :
4768
+ self .logger .info ('Request for TCP forwarding from '
4769
+ '%s to %s denied by application' ,
4770
+ (orig_host , orig_port ), dest_path )
4771
+
4772
+ raise ChannelOpenError (OPEN_ADMINISTRATIVELY_PROHIBITED ,
4773
+ 'Connection forwarding denied' )
4774
+
4731
4775
return (await self .create_unix_connection (session_factory ,
4732
4776
dest_path ))
4733
4777
@@ -5737,6 +5781,10 @@ async def _finish_port_forward(self, listen_host: str,
5737
5781
if listener is True :
5738
5782
listener = await self .forward_local_port (
5739
5783
listen_host , listen_port , listen_host , listen_port )
5784
+ elif callable (listener ):
5785
+ listener = await self .forward_local_port (
5786
+ listen_host , listen_port ,
5787
+ listen_host , listen_port , listener )
5740
5788
except OSError :
5741
5789
self .logger .debug1 ('Failed to create TCP listener' )
5742
5790
self ._report_global_response (False )
0 commit comments