Skip to content

Commit e166c44

Browse files
Merge pull request #123785 from seans3/streamtunnel-unit-tests
Adds unit tests to `PortForward` streamtunnel Kubernetes-commit: 065a0f2d5116cc6f66eb8a0c6296e05b90838ec8
2 parents ec72042 + 5e1f756 commit e166c44

File tree

1 file changed

+147
-9
lines changed

1 file changed

+147
-9
lines changed

pkg/util/proxy/streamtunnel_test.go

Lines changed: 147 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ package proxy
1919
import (
2020
"bytes"
2121
"crypto/rand"
22+
"errors"
23+
"fmt"
2224
"io"
2325
"net"
2426
"net/http"
@@ -48,7 +50,6 @@ func TestTunnelingHandler_UpgradeStreamingAndTunneling(t *testing.T) {
4850
defer close(streamChan)
4951
stopServerChan := make(chan struct{})
5052
defer close(stopServerChan)
51-
// Create fake upstream SPDY server.
5253
spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
5354
_, err := httpstream.Handshake(req, w, []string{constants.PortForwardV1Name})
5455
require.NoError(t, err)
@@ -107,6 +108,120 @@ func TestTunnelingHandler_UpgradeStreamingAndTunneling(t *testing.T) {
107108
assert.Equal(t, randomData, actual, "error validating tunneled random data")
108109
}
109110

111+
func TestTunnelingResponseWriter_Hijack(t *testing.T) {
112+
// Regular hijack returns connection, nil bufio, and no error.
113+
trw := &tunnelingResponseWriter{conn: &mockConn{}}
114+
assert.False(t, trw.hijacked, "hijacked field starts false before Hijack()")
115+
assert.False(t, trw.written, "written field startes false before Hijack()")
116+
actual, bufio, err := trw.Hijack()
117+
assert.NoError(t, err, "Hijack() does not return error")
118+
assert.NotNil(t, actual, "conn returned from Hijack() is not nil")
119+
assert.Nil(t, bufio, "bufio returned from Hijack() is always nil")
120+
assert.True(t, trw.hijacked, "hijacked field becomes true after Hijack()")
121+
assert.False(t, trw.written, "written field stays false after Hijack()")
122+
// Hijacking after writing to response writer is an error.
123+
trw = &tunnelingResponseWriter{written: true}
124+
_, _, err = trw.Hijack()
125+
assert.Error(t, err, "Hijack after writing to response writer is error")
126+
assert.True(t, strings.Contains(err.Error(), "connection has already been written to"))
127+
// Hijacking after already hijacked is an error.
128+
trw = &tunnelingResponseWriter{hijacked: true}
129+
_, _, err = trw.Hijack()
130+
assert.Error(t, err, "Hijack after writing to response writer is error")
131+
assert.True(t, strings.Contains(err.Error(), "connection has already been hijacked"))
132+
}
133+
134+
func TestTunnelingResponseWriter_DelegateResponseWriter(t *testing.T) {
135+
// Validate Header() for delegate response writer.
136+
expectedHeader := http.Header{}
137+
expectedHeader.Set("foo", "bar")
138+
trw := &tunnelingResponseWriter{w: &mockResponseWriter{header: expectedHeader}}
139+
assert.Equal(t, expectedHeader, trw.Header(), "")
140+
// Validate Write() for delegate response writer.
141+
expectedWrite := []byte("this is a test write string")
142+
assert.False(t, trw.written, "written field is before Write()")
143+
_, err := trw.Write(expectedWrite)
144+
assert.NoError(t, err, "No error expected after Write() on tunneling response writer")
145+
assert.True(t, trw.written, "written field is set after writing to tunneling response writer")
146+
// Writing to response writer after hijacked is an error.
147+
trw.hijacked = true
148+
_, err = trw.Write(expectedWrite)
149+
assert.Error(t, err, "Writing to ResponseWriter after Hijack() is an error")
150+
assert.True(t, errors.Is(err, http.ErrHijacked), "Hijacked error returned if writing after hijacked")
151+
// Validate WriteHeader().
152+
trw = &tunnelingResponseWriter{w: &mockResponseWriter{}}
153+
expectedStatusCode := 201
154+
assert.False(t, trw.written, "Written field originally false in delegate response writer")
155+
trw.WriteHeader(expectedStatusCode)
156+
assert.Equal(t, expectedStatusCode, trw.w.(*mockResponseWriter).statusCode, "Expected written status code is correct")
157+
assert.True(t, trw.written, "Written field set to true after writing delegate response writer")
158+
// Response writer already written to does not write status.
159+
trw = &tunnelingResponseWriter{w: &mockResponseWriter{}}
160+
trw.written = true
161+
trw.WriteHeader(expectedStatusCode)
162+
assert.Equal(t, 0, trw.w.(*mockResponseWriter).statusCode, "No status code for previously written response writer")
163+
// Hijacked response writer does not write status.
164+
trw = &tunnelingResponseWriter{w: &mockResponseWriter{}}
165+
trw.hijacked = true
166+
trw.WriteHeader(expectedStatusCode)
167+
assert.Equal(t, 0, trw.w.(*mockResponseWriter).statusCode, "No status code written to hijacked response writer")
168+
assert.False(t, trw.written, "Hijacked response writer does not write status")
169+
// Writing "101 Switching Protocols" status is an error, since it should happen via hijacked connection.
170+
trw = &tunnelingResponseWriter{w: &mockResponseWriter{header: http.Header{}}}
171+
trw.WriteHeader(http.StatusSwitchingProtocols)
172+
assert.Equal(t, http.StatusInternalServerError, trw.w.(*mockResponseWriter).statusCode, "Internal server error written")
173+
}
174+
175+
func TestTunnelingWebsocketUpgraderConn_LocalRemoteAddress(t *testing.T) {
176+
expectedLocalAddr := &net.TCPAddr{
177+
IP: net.IPv4(127, 0, 0, 1),
178+
Port: 80,
179+
}
180+
expectedRemoteAddr := &net.TCPAddr{
181+
IP: net.IPv4(127, 0, 0, 2),
182+
Port: 443,
183+
}
184+
tc := &tunnelingWebsocketUpgraderConn{
185+
conn: &mockConn{
186+
localAddr: expectedLocalAddr,
187+
remoteAddr: expectedRemoteAddr,
188+
},
189+
}
190+
assert.Equal(t, expectedLocalAddr, tc.LocalAddr(), "LocalAddr() returns expected TCPAddr")
191+
assert.Equal(t, expectedRemoteAddr, tc.RemoteAddr(), "RemoteAddr() returns expected TCPAddr")
192+
// Connection nil, returns empty address
193+
tc.conn = nil
194+
assert.Equal(t, noopAddr{}, tc.LocalAddr(), "nil connection, LocalAddr() returns noopAddr")
195+
assert.Equal(t, noopAddr{}, tc.RemoteAddr(), "nil connection, RemoteAddr() returns noopAddr")
196+
// Validate the empty strings from noopAddr
197+
assert.Equal(t, "", noopAddr{}.Network(), "noopAddr Network() returns empty string")
198+
assert.Equal(t, "", noopAddr{}.String(), "noopAddr String() returns empty string")
199+
}
200+
201+
func TestTunnelingWebsocketUpgraderConn_SetDeadline(t *testing.T) {
202+
tc := &tunnelingWebsocketUpgraderConn{conn: &mockConn{}}
203+
expected := time.Now()
204+
assert.Nil(t, tc.SetDeadline(expected), "SetDeadline does not return error")
205+
assert.Equal(t, expected, tc.conn.(*mockConn).readDeadline, "SetDeadline() sets read deadline")
206+
assert.Equal(t, expected, tc.conn.(*mockConn).writeDeadline, "SetDeadline() sets write deadline")
207+
expected = time.Now()
208+
assert.Nil(t, tc.SetWriteDeadline(expected), "SetWriteDeadline does not return error")
209+
assert.Equal(t, expected, tc.conn.(*mockConn).writeDeadline, "Expected write deadline set")
210+
expected = time.Now()
211+
assert.Nil(t, tc.SetReadDeadline(expected), "SetReadDeadline does not return error")
212+
assert.Equal(t, expected, tc.conn.(*mockConn).readDeadline, "Expected read deadline set")
213+
expectedErr := fmt.Errorf("deadline error")
214+
tc = &tunnelingWebsocketUpgraderConn{conn: &mockConn{deadlineErr: expectedErr}}
215+
expected = time.Now()
216+
actualErr := tc.SetDeadline(expected)
217+
assert.Equal(t, expectedErr, actualErr, "SetDeadline() expected error returned")
218+
// Connection nil, returns nil error.
219+
tc.conn = nil
220+
assert.Nil(t, tc.SetDeadline(expected), "SetDeadline() with nil connection always returns nil error")
221+
assert.Nil(t, tc.SetWriteDeadline(expected), "SetWriteDeadline() with nil connection always returns nil error")
222+
assert.Nil(t, tc.SetReadDeadline(expected), "SetReadDeadline() with nil connection always returns nil error")
223+
}
224+
110225
var expectedContentLengthHeaders = http.Header{
111226
"Content-Length": []string{"25"},
112227
"Date": []string{"Sun, 25 Feb 2024 08:09:25 GMT"},
@@ -330,21 +445,44 @@ func (m *mockConnInitializer) InitializeWrite(backendResponse *http.Response, ba
330445
var _ net.Conn = &mockConn{}
331446

332447
type mockConn struct {
333-
written []byte
448+
written []byte
449+
localAddr *net.TCPAddr
450+
remoteAddr *net.TCPAddr
451+
readDeadline time.Time
452+
writeDeadline time.Time
453+
deadlineErr error
334454
}
335455

336456
func (mc *mockConn) Write(p []byte) (int, error) {
337457
mc.written = append(mc.written, p...)
338458
return len(p), nil
339459
}
340460

341-
func (mc *mockConn) Read(p []byte) (int, error) { return 0, nil }
342-
func (mc *mockConn) Close() error { return nil }
343-
func (mc *mockConn) LocalAddr() net.Addr { return &net.TCPAddr{} }
344-
func (mc *mockConn) RemoteAddr() net.Addr { return &net.TCPAddr{} }
345-
func (mc *mockConn) SetDeadline(t time.Time) error { return nil }
346-
func (mc *mockConn) SetReadDeadline(t time.Time) error { return nil }
347-
func (mc *mockConn) SetWriteDeadline(t time.Time) error { return nil }
461+
func (mc *mockConn) Read(p []byte) (int, error) { return 0, nil }
462+
func (mc *mockConn) Close() error { return nil }
463+
func (mc *mockConn) LocalAddr() net.Addr { return mc.localAddr }
464+
func (mc *mockConn) RemoteAddr() net.Addr { return mc.remoteAddr }
465+
func (mc *mockConn) SetDeadline(t time.Time) error {
466+
mc.SetReadDeadline(t) //nolint:errcheck
467+
mc.SetWriteDeadline(t) // nolint:errcheck
468+
return mc.deadlineErr
469+
}
470+
func (mc *mockConn) SetReadDeadline(t time.Time) error { mc.readDeadline = t; return mc.deadlineErr }
471+
func (mc *mockConn) SetWriteDeadline(t time.Time) error { mc.writeDeadline = t; return mc.deadlineErr }
472+
473+
// mockResponseWriter implements "http.ResponseWriter" interface
474+
type mockResponseWriter struct {
475+
header http.Header
476+
written []byte
477+
statusCode int
478+
}
479+
480+
func (mrw *mockResponseWriter) Header() http.Header { return mrw.header }
481+
func (mrw *mockResponseWriter) Write(p []byte) (int, error) {
482+
mrw.written = append(mrw.written, p...)
483+
return len(p), nil
484+
}
485+
func (mrw *mockResponseWriter) WriteHeader(statusCode int) { mrw.statusCode = statusCode }
348486

349487
// fakeResponder implements "rest.Responder" interface.
350488
var _ rest.Responder = &fakeResponder{}

0 commit comments

Comments
 (0)