@@ -19,6 +19,8 @@ package proxy
19
19
import (
20
20
"bytes"
21
21
"crypto/rand"
22
+ "errors"
23
+ "fmt"
22
24
"io"
23
25
"net"
24
26
"net/http"
@@ -48,7 +50,6 @@ func TestTunnelingHandler_UpgradeStreamingAndTunneling(t *testing.T) {
48
50
defer close (streamChan )
49
51
stopServerChan := make (chan struct {})
50
52
defer close (stopServerChan )
51
- // Create fake upstream SPDY server.
52
53
spdyServer := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , req * http.Request ) {
53
54
_ , err := httpstream .Handshake (req , w , []string {constants .PortForwardV1Name })
54
55
require .NoError (t , err )
@@ -107,6 +108,120 @@ func TestTunnelingHandler_UpgradeStreamingAndTunneling(t *testing.T) {
107
108
assert .Equal (t , randomData , actual , "error validating tunneled random data" )
108
109
}
109
110
111
+ func TestTunnelingResponseWriter_Hijack (t * testing.T ) {
112
+ // Regular hijack returns connection, nil bufio, and no error.
113
+ trw := & tunnelingResponseWriter {conn : & mockConn {}}
114
+ assert .False (t , trw .hijacked , "hijacked field starts false before Hijack()" )
115
+ assert .False (t , trw .written , "written field startes false before Hijack()" )
116
+ actual , bufio , err := trw .Hijack ()
117
+ assert .NoError (t , err , "Hijack() does not return error" )
118
+ assert .NotNil (t , actual , "conn returned from Hijack() is not nil" )
119
+ assert .Nil (t , bufio , "bufio returned from Hijack() is always nil" )
120
+ assert .True (t , trw .hijacked , "hijacked field becomes true after Hijack()" )
121
+ assert .False (t , trw .written , "written field stays false after Hijack()" )
122
+ // Hijacking after writing to response writer is an error.
123
+ trw = & tunnelingResponseWriter {written : true }
124
+ _ , _ , err = trw .Hijack ()
125
+ assert .Error (t , err , "Hijack after writing to response writer is error" )
126
+ assert .True (t , strings .Contains (err .Error (), "connection has already been written to" ))
127
+ // Hijacking after already hijacked is an error.
128
+ trw = & tunnelingResponseWriter {hijacked : true }
129
+ _ , _ , err = trw .Hijack ()
130
+ assert .Error (t , err , "Hijack after writing to response writer is error" )
131
+ assert .True (t , strings .Contains (err .Error (), "connection has already been hijacked" ))
132
+ }
133
+
134
+ func TestTunnelingResponseWriter_DelegateResponseWriter (t * testing.T ) {
135
+ // Validate Header() for delegate response writer.
136
+ expectedHeader := http.Header {}
137
+ expectedHeader .Set ("foo" , "bar" )
138
+ trw := & tunnelingResponseWriter {w : & mockResponseWriter {header : expectedHeader }}
139
+ assert .Equal (t , expectedHeader , trw .Header (), "" )
140
+ // Validate Write() for delegate response writer.
141
+ expectedWrite := []byte ("this is a test write string" )
142
+ assert .False (t , trw .written , "written field is before Write()" )
143
+ _ , err := trw .Write (expectedWrite )
144
+ assert .NoError (t , err , "No error expected after Write() on tunneling response writer" )
145
+ assert .True (t , trw .written , "written field is set after writing to tunneling response writer" )
146
+ // Writing to response writer after hijacked is an error.
147
+ trw .hijacked = true
148
+ _ , err = trw .Write (expectedWrite )
149
+ assert .Error (t , err , "Writing to ResponseWriter after Hijack() is an error" )
150
+ assert .True (t , errors .Is (err , http .ErrHijacked ), "Hijacked error returned if writing after hijacked" )
151
+ // Validate WriteHeader().
152
+ trw = & tunnelingResponseWriter {w : & mockResponseWriter {}}
153
+ expectedStatusCode := 201
154
+ assert .False (t , trw .written , "Written field originally false in delegate response writer" )
155
+ trw .WriteHeader (expectedStatusCode )
156
+ assert .Equal (t , expectedStatusCode , trw .w .(* mockResponseWriter ).statusCode , "Expected written status code is correct" )
157
+ assert .True (t , trw .written , "Written field set to true after writing delegate response writer" )
158
+ // Response writer already written to does not write status.
159
+ trw = & tunnelingResponseWriter {w : & mockResponseWriter {}}
160
+ trw .written = true
161
+ trw .WriteHeader (expectedStatusCode )
162
+ assert .Equal (t , 0 , trw .w .(* mockResponseWriter ).statusCode , "No status code for previously written response writer" )
163
+ // Hijacked response writer does not write status.
164
+ trw = & tunnelingResponseWriter {w : & mockResponseWriter {}}
165
+ trw .hijacked = true
166
+ trw .WriteHeader (expectedStatusCode )
167
+ assert .Equal (t , 0 , trw .w .(* mockResponseWriter ).statusCode , "No status code written to hijacked response writer" )
168
+ assert .False (t , trw .written , "Hijacked response writer does not write status" )
169
+ // Writing "101 Switching Protocols" status is an error, since it should happen via hijacked connection.
170
+ trw = & tunnelingResponseWriter {w : & mockResponseWriter {header : http.Header {}}}
171
+ trw .WriteHeader (http .StatusSwitchingProtocols )
172
+ assert .Equal (t , http .StatusInternalServerError , trw .w .(* mockResponseWriter ).statusCode , "Internal server error written" )
173
+ }
174
+
175
+ func TestTunnelingWebsocketUpgraderConn_LocalRemoteAddress (t * testing.T ) {
176
+ expectedLocalAddr := & net.TCPAddr {
177
+ IP : net .IPv4 (127 , 0 , 0 , 1 ),
178
+ Port : 80 ,
179
+ }
180
+ expectedRemoteAddr := & net.TCPAddr {
181
+ IP : net .IPv4 (127 , 0 , 0 , 2 ),
182
+ Port : 443 ,
183
+ }
184
+ tc := & tunnelingWebsocketUpgraderConn {
185
+ conn : & mockConn {
186
+ localAddr : expectedLocalAddr ,
187
+ remoteAddr : expectedRemoteAddr ,
188
+ },
189
+ }
190
+ assert .Equal (t , expectedLocalAddr , tc .LocalAddr (), "LocalAddr() returns expected TCPAddr" )
191
+ assert .Equal (t , expectedRemoteAddr , tc .RemoteAddr (), "RemoteAddr() returns expected TCPAddr" )
192
+ // Connection nil, returns empty address
193
+ tc .conn = nil
194
+ assert .Equal (t , noopAddr {}, tc .LocalAddr (), "nil connection, LocalAddr() returns noopAddr" )
195
+ assert .Equal (t , noopAddr {}, tc .RemoteAddr (), "nil connection, RemoteAddr() returns noopAddr" )
196
+ // Validate the empty strings from noopAddr
197
+ assert .Equal (t , "" , noopAddr {}.Network (), "noopAddr Network() returns empty string" )
198
+ assert .Equal (t , "" , noopAddr {}.String (), "noopAddr String() returns empty string" )
199
+ }
200
+
201
+ func TestTunnelingWebsocketUpgraderConn_SetDeadline (t * testing.T ) {
202
+ tc := & tunnelingWebsocketUpgraderConn {conn : & mockConn {}}
203
+ expected := time .Now ()
204
+ assert .Nil (t , tc .SetDeadline (expected ), "SetDeadline does not return error" )
205
+ assert .Equal (t , expected , tc .conn .(* mockConn ).readDeadline , "SetDeadline() sets read deadline" )
206
+ assert .Equal (t , expected , tc .conn .(* mockConn ).writeDeadline , "SetDeadline() sets write deadline" )
207
+ expected = time .Now ()
208
+ assert .Nil (t , tc .SetWriteDeadline (expected ), "SetWriteDeadline does not return error" )
209
+ assert .Equal (t , expected , tc .conn .(* mockConn ).writeDeadline , "Expected write deadline set" )
210
+ expected = time .Now ()
211
+ assert .Nil (t , tc .SetReadDeadline (expected ), "SetReadDeadline does not return error" )
212
+ assert .Equal (t , expected , tc .conn .(* mockConn ).readDeadline , "Expected read deadline set" )
213
+ expectedErr := fmt .Errorf ("deadline error" )
214
+ tc = & tunnelingWebsocketUpgraderConn {conn : & mockConn {deadlineErr : expectedErr }}
215
+ expected = time .Now ()
216
+ actualErr := tc .SetDeadline (expected )
217
+ assert .Equal (t , expectedErr , actualErr , "SetDeadline() expected error returned" )
218
+ // Connection nil, returns nil error.
219
+ tc .conn = nil
220
+ assert .Nil (t , tc .SetDeadline (expected ), "SetDeadline() with nil connection always returns nil error" )
221
+ assert .Nil (t , tc .SetWriteDeadline (expected ), "SetWriteDeadline() with nil connection always returns nil error" )
222
+ assert .Nil (t , tc .SetReadDeadline (expected ), "SetReadDeadline() with nil connection always returns nil error" )
223
+ }
224
+
110
225
var expectedContentLengthHeaders = http.Header {
111
226
"Content-Length" : []string {"25" },
112
227
"Date" : []string {"Sun, 25 Feb 2024 08:09:25 GMT" },
@@ -330,21 +445,44 @@ func (m *mockConnInitializer) InitializeWrite(backendResponse *http.Response, ba
330
445
var _ net.Conn = & mockConn {}
331
446
332
447
type mockConn struct {
333
- written []byte
448
+ written []byte
449
+ localAddr * net.TCPAddr
450
+ remoteAddr * net.TCPAddr
451
+ readDeadline time.Time
452
+ writeDeadline time.Time
453
+ deadlineErr error
334
454
}
335
455
336
456
func (mc * mockConn ) Write (p []byte ) (int , error ) {
337
457
mc .written = append (mc .written , p ... )
338
458
return len (p ), nil
339
459
}
340
460
341
- func (mc * mockConn ) Read (p []byte ) (int , error ) { return 0 , nil }
342
- func (mc * mockConn ) Close () error { return nil }
343
- func (mc * mockConn ) LocalAddr () net.Addr { return & net.TCPAddr {} }
344
- func (mc * mockConn ) RemoteAddr () net.Addr { return & net.TCPAddr {} }
345
- func (mc * mockConn ) SetDeadline (t time.Time ) error { return nil }
346
- func (mc * mockConn ) SetReadDeadline (t time.Time ) error { return nil }
347
- func (mc * mockConn ) SetWriteDeadline (t time.Time ) error { return nil }
461
+ func (mc * mockConn ) Read (p []byte ) (int , error ) { return 0 , nil }
462
+ func (mc * mockConn ) Close () error { return nil }
463
+ func (mc * mockConn ) LocalAddr () net.Addr { return mc .localAddr }
464
+ func (mc * mockConn ) RemoteAddr () net.Addr { return mc .remoteAddr }
465
+ func (mc * mockConn ) SetDeadline (t time.Time ) error {
466
+ mc .SetReadDeadline (t ) //nolint:errcheck
467
+ mc .SetWriteDeadline (t ) // nolint:errcheck
468
+ return mc .deadlineErr
469
+ }
470
+ func (mc * mockConn ) SetReadDeadline (t time.Time ) error { mc .readDeadline = t ; return mc .deadlineErr }
471
+ func (mc * mockConn ) SetWriteDeadline (t time.Time ) error { mc .writeDeadline = t ; return mc .deadlineErr }
472
+
473
+ // mockResponseWriter implements "http.ResponseWriter" interface
474
+ type mockResponseWriter struct {
475
+ header http.Header
476
+ written []byte
477
+ statusCode int
478
+ }
479
+
480
+ func (mrw * mockResponseWriter ) Header () http.Header { return mrw .header }
481
+ func (mrw * mockResponseWriter ) Write (p []byte ) (int , error ) {
482
+ mrw .written = append (mrw .written , p ... )
483
+ return len (p ), nil
484
+ }
485
+ func (mrw * mockResponseWriter ) WriteHeader (statusCode int ) { mrw .statusCode = statusCode }
348
486
349
487
// fakeResponder implements "rest.Responder" interface.
350
488
var _ rest.Responder = & fakeResponder {}
0 commit comments