Skip to content

Commit ee70acf

Browse files
committed
remove packetWriter and simplify tests
1 parent 60ce788 commit ee70acf

File tree

7 files changed

+60
-94
lines changed

7 files changed

+60
-94
lines changed

compress.go

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ func zCompress(src []byte, dst io.Writer) error {
8686
return nil
8787
}
8888

89-
type compressor struct {
89+
type decompressor struct {
9090
mc *mysqlConn
9191
// read buffer (FIFO).
9292
// We can not reuse already-read buffer until dropping Go 1.20 support.
@@ -95,13 +95,13 @@ type compressor struct {
9595
bytesBuf []byte
9696
}
9797

98-
func newCompressor(mc *mysqlConn) *compressor {
99-
return &compressor{
98+
func newDecompressor(mc *mysqlConn) *decompressor {
99+
return &decompressor{
100100
mc: mc,
101101
}
102102
}
103103

104-
func (c *compressor) readNext(need int) ([]byte, error) {
104+
func (c *decompressor) readNext(need int) ([]byte, error) {
105105
for len(c.bytesBuf) < need {
106106
if err := c.uncompressPacket(); err != nil {
107107
return nil, err
@@ -113,7 +113,7 @@ func (c *compressor) readNext(need int) ([]byte, error) {
113113
return data, nil
114114
}
115115

116-
func (c *compressor) uncompressPacket() error {
116+
func (c *decompressor) uncompressPacket() error {
117117
header, err := c.mc.buf.readNext(7) // size of compressed header
118118
if err != nil {
119119
return err
@@ -166,9 +166,11 @@ func (c *compressor) uncompressPacket() error {
166166

167167
const maxPayloadLen = maxPacketSize - 4
168168

169-
func (c *compressor) Write(data []byte) (int, error) {
170-
totalBytes := len(data)
171-
dataLen := len(data)
169+
// writeCompressed sends one or some packets with compression.
170+
// Use this instead of mc.netConn.Write() when mc.compress is true.
171+
func (mc *mysqlConn) writeCompressed(packets []byte) (int, error) {
172+
totalBytes := len(packets)
173+
dataLen := len(packets)
172174
blankHeader := make([]byte, 7)
173175
var buf bytes.Buffer
174176

@@ -177,7 +179,7 @@ func (c *compressor) Write(data []byte) (int, error) {
177179
if payloadLen > maxPayloadLen {
178180
payloadLen = maxPayloadLen
179181
}
180-
payload := data[:payloadLen]
182+
payload := packets[:payloadLen]
181183
uncompressedLen := payloadLen
182184

183185
if _, err := buf.Write(blankHeader); err != nil {
@@ -194,11 +196,11 @@ func (c *compressor) Write(data []byte) (int, error) {
194196
zCompress(payload, &buf)
195197
}
196198

197-
if err := c.writeCompressedPacket(buf.Bytes(), uncompressedLen); err != nil {
199+
if err := mc.writeCompressedPacket(buf.Bytes(), uncompressedLen); err != nil {
198200
return 0, err
199201
}
200202
dataLen -= payloadLen
201-
data = data[payloadLen:]
203+
packets = packets[payloadLen:]
202204
buf.Reset()
203205
}
204206

@@ -207,32 +209,32 @@ func (c *compressor) Write(data []byte) (int, error) {
207209

208210
// writeCompressedPacket writes a compressed packet with header.
209211
// data should start with 7 size space for header followed by payload.
210-
func (c *compressor) writeCompressedPacket(data []byte, uncompressedLen int) error {
212+
func (mc *mysqlConn) writeCompressedPacket(data []byte, uncompressedLen int) error {
211213
comprLength := len(data) - 7
212214
if debugTrace {
213-
c.mc.cfg.Logger.Print(
215+
mc.cfg.Logger.Print(
214216
fmt.Sprintf(
215217
"writeCompressedPacket: comprLength=%v, uncompressedLen=%v, seq=%v",
216-
comprLength, uncompressedLen, c.mc.compressSequence))
218+
comprLength, uncompressedLen, mc.compressSequence))
217219
}
218220

219221
// compression header
220222
data[0] = byte(0xff & comprLength)
221223
data[1] = byte(0xff & (comprLength >> 8))
222224
data[2] = byte(0xff & (comprLength >> 16))
223225

224-
data[3] = c.mc.compressSequence
226+
data[3] = mc.compressSequence
225227

226228
// this value is never greater than maxPayloadLength
227229
data[4] = byte(0xff & uncompressedLen)
228230
data[5] = byte(0xff & (uncompressedLen >> 8))
229231
data[6] = byte(0xff & (uncompressedLen >> 16))
230232

231-
if _, err := c.mc.netConn.Write(data); err != nil {
232-
c.mc.cfg.Logger.Print(err)
233+
if _, err := mc.netConn.Write(data); err != nil {
234+
mc.cfg.Logger.Print(err)
233235
return err
234236
}
235237

236-
c.mc.compressSequence++
238+
mc.compressSequence++
237239
return nil
238240
}

compress_test.go

Lines changed: 28 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import (
1313
"crypto/rand"
1414
"fmt"
1515
"io"
16-
"net"
1716
"testing"
1817
)
1918

@@ -23,85 +22,37 @@ func makeRandByteSlice(size int) []byte {
2322
return randBytes
2423
}
2524

26-
func newMockConn() *mysqlConn {
27-
newConn := &mysqlConn{cfg: NewConfig()}
28-
return newConn
29-
}
30-
31-
func newMockBuf(data []byte) buffer {
32-
return buffer{
33-
buf: data,
34-
length: len(data),
35-
}
36-
}
37-
38-
type dummyConn struct {
39-
buf bytes.Buffer
40-
net.Conn
41-
}
42-
43-
func (c *dummyConn) Write(data []byte) (int, error) {
44-
return c.buf.Write(data)
45-
}
46-
4725
// compressHelper compresses uncompressedPacket and checks state variables
4826
func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []byte {
49-
// get status variables
50-
51-
cs := mc.compressSequence
52-
53-
var b dummyConn
54-
mc.netConn = &b
55-
cw := newCompressor(mc)
56-
57-
n, err := cw.Write(uncompressedPacket)
27+
conn := new(mockConn)
28+
mc.netConn = conn
5829

30+
n, err := mc.writeCompressed(uncompressedPacket)
5931
if err != nil {
60-
t.Fatal(err.Error())
32+
t.Fatal(err)
6133
}
62-
6334
if n != len(uncompressedPacket) {
6435
t.Fatalf("expected to write %d bytes, wrote %d bytes", len(uncompressedPacket), n)
6536
}
66-
67-
if len(uncompressedPacket) > 0 {
68-
if mc.compressSequence != (cs + 1) {
69-
t.Fatalf("mc.compressionSequence updated incorrectly, expected %d and saw %d", (cs + 1), mc.compressSequence)
70-
}
71-
72-
} else {
73-
if mc.compressSequence != cs {
74-
t.Fatalf("mc.compressionSequence updated incorrectly for case of empty write, expected %d and saw %d", cs, mc.compressSequence)
75-
}
76-
}
77-
78-
return b.buf.Bytes()
37+
return conn.written
7938
}
8039

8140
// uncompressHelper uncompresses compressedPacket and checks state variables
8241
func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte, expSize int) []byte {
83-
// get status variables
84-
cs := mc.compressSequence
85-
8642
// mocking out buf variable
87-
mc.buf = newMockBuf(compressedPacket)
88-
cr := newCompressor(mc)
43+
conn := new(mockConn)
44+
conn.data = compressedPacket
45+
mc.buf.nc = conn
46+
cr := newDecompressor(mc)
8947

9048
uncompressedPacket, err := cr.readNext(expSize)
9149
if err != nil {
9250
if err != io.EOF {
9351
t.Fatalf("non-nil/non-EOF error when reading contents: %s", err.Error())
9452
}
9553
}
96-
97-
if expSize > 0 {
98-
if mc.compressSequence != (cs + 1) {
99-
t.Fatalf("mc.compressionSequence updated incorrectly, expected %d and saw %d", (cs + 1), mc.compressSequence)
100-
}
101-
} else {
102-
if mc.compressSequence != cs {
103-
t.Fatalf("mc.compressionSequence updated incorrectly for case of empty read, expected %d and saw %d", cs, mc.compressSequence)
104-
}
54+
if len(uncompressedPacket) != expSize {
55+
t.Errorf("uncompressed size is unexpected. expected %d but got %d", expSize, len(uncompressedPacket))
10556
}
10657
return uncompressedPacket
10758
}
@@ -141,20 +92,33 @@ func TestRoundtrip(t *testing.T) {
14192
{uncompressed: makeRandByteSlice(32768),
14293
desc: "32768 rand bytes",
14394
},
144-
{uncompressed: makeRandByteSlice(33000),
145-
desc: "33000 rand bytes",
95+
{uncompressed: bytes.Repeat(makeRandByteSlice(100), 10000),
96+
desc: "100 rand * 10000 repeat bytes",
14697
},
14798
}
14899

149-
cSend := newMockConn()
150-
cReceive := newMockConn()
100+
_, cSend := newRWMockConn(0)
101+
cSend.compress = true
102+
_, cReceive := newRWMockConn(0)
103+
cReceive.compress = true
151104

152105
for _, test := range tests {
153106
s := fmt.Sprintf("Test roundtrip with %s", test.desc)
107+
cSend.resetSequenceNr()
108+
cReceive.resetSequenceNr()
154109

155110
uncompressed := roundtripHelper(t, cSend, cReceive, test.uncompressed)
156111
if !bytes.Equal(uncompressed, test.uncompressed) {
157112
t.Fatalf("%s: roundtrip failed", s)
158113
}
114+
115+
if cSend.sequence != cReceive.sequence {
116+
t.Errorf("inconsistent sequence number: send=%v recv=%v",
117+
cSend.sequence, cReceive.sequence)
118+
}
119+
if cSend.compressSequence != cReceive.compressSequence {
120+
t.Errorf("inconsistent compress sequence number: send=%v recv=%v",
121+
cSend.compressSequence, cReceive.compressSequence)
122+
}
159123
}
160124
}

connection.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ type mysqlConn struct {
2626
rawConn net.Conn // underlying connection when netConn is TLS connection.
2727
result mysqlResult // managed by clearResult() and handleOkPacket().
2828
packetReader packetReader
29-
packetWriter io.Writer
3029
cfg *Config
3130
connector *connector
3231
maxAllowedPacket int

connection_test.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,6 @@ func TestPingMarkBadConnection(t *testing.T) {
169169
netConn: nc,
170170
buf: buf,
171171
packetReader: &buf,
172-
packetWriter: nc,
173172
maxAllowedPacket: defaultMaxAllowedPacket,
174173
}
175174

@@ -188,7 +187,6 @@ func TestPingErrInvalidConn(t *testing.T) {
188187
netConn: nc,
189188
buf: buf,
190189
packetReader: &buf,
191-
packetWriter: nc,
192190
maxAllowedPacket: defaultMaxAllowedPacket,
193191
closech: make(chan struct{}),
194192
cfg: NewConfig(),

connector.go

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,6 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
125125
mc.buf = newBuffer(mc.netConn)
126126
// packet reader and writer in handshake are never compressed
127127
mc.packetReader = &mc.buf
128-
mc.packetWriter = mc.netConn
129128

130129
// Set I/O timeouts
131130
mc.buf.timeout = mc.cfg.ReadTimeout
@@ -168,10 +167,9 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
168167
return nil, err
169168
}
170169

171-
if mc.compress {
172-
cmpr := newCompressor(mc)
173-
mc.packetReader = cmpr
174-
mc.packetWriter = cmpr
170+
if mc.cfg.compress && mc.flags&clientCompress == clientCompress {
171+
mc.compress = true
172+
mc.packetReader = newDecompressor(mc)
175173
}
176174
if mc.cfg.MaxAllowedPacket > 0 {
177175
mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket

packets.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,15 @@ func (mc *mysqlConn) writePacket(data []byte) error {
131131
}
132132
}
133133

134-
n, err := mc.packetWriter.Write(data[:4+size])
134+
var (
135+
n int
136+
err error
137+
)
138+
if mc.compress {
139+
n, err = mc.writeCompressed(data[:4+size])
140+
} else {
141+
n, err = mc.netConn.Write(data[:4+size])
142+
}
135143
if err == nil && n == 4+size {
136144
mc.sequence++
137145
if size != maxPacketSize {
@@ -278,7 +286,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
278286
}
279287
if mc.cfg.compress && mc.flags&clientCompress == clientCompress {
280288
clientFlags |= clientCompress
281-
mc.compress = true
282289
}
283290
// To enable TLS / SSL
284291
if mc.cfg.TLS != nil {
@@ -368,7 +375,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
368375
mc.rawConn = mc.netConn
369376
mc.netConn = tlsConn
370377
mc.buf.nc = tlsConn
371-
mc.packetWriter = mc.netConn
372378
}
373379

374380
// User [null terminated string]

packets_test.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,6 @@ func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) {
101101
mc := &mysqlConn{
102102
buf: buf,
103103
packetReader: &buf,
104-
packetWriter: conn,
105104
cfg: connector.cfg,
106105
connector: connector,
107106
netConn: conn,

0 commit comments

Comments
 (0)