Skip to content

Commit cce6057

Browse files
author
Divjot Arora
authored
GODRIVER-1489 Add ability to stream wire messages (#405)
1 parent b0ac484 commit cce6057

File tree

6 files changed

+188
-22
lines changed

6 files changed

+188
-22
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ driver-test-data.tar.gz
1111
perf
1212
**mongocryptd.pid
1313
*.test
14+
**.md

x/mongo/driver/driver.go

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,21 @@ type Expirable interface {
6565
Alive() bool
6666
}
6767

68+
// StreamerConnection represents a Connection that supports streaming wire protocol messages using the moreToCome and
69+
// exhaustAllowed flags.
70+
//
71+
// The SetStreaming and CurrentlyStreaming functions correspond to the moreToCome flag on server responses. If a
72+
// response has moreToCome set, SetStreaming(true) will be called and CurrentlyStreaming() should return true.
73+
//
74+
// CanStream corresponds to the exhaustAllowed flag. The operations layer will set exhaustAllowed on outgoing wire
75+
// messages to inform the server that the driver supports streaming.
76+
type StreamerConnection interface {
77+
Connection
78+
SetStreaming(bool)
79+
CurrentlyStreaming() bool
80+
SupportsStreaming() bool
81+
}
82+
6883
// Compressor is an interface used to compress wire messages. If a Connection supports compression
6984
// it should implement this interface as well. The CompressWireMessage method will be called during
7085
// the execution of an operation if the wire message is allowed to be compressed.
@@ -127,20 +142,10 @@ func (ssd SingleConnectionDeployment) SupportsRetryWrites() bool { return false
127142
func (ssd SingleConnectionDeployment) Kind() description.TopologyKind { return description.Single }
128143

129144
// Connection implements the Server interface. It always returns the embedded connection.
130-
//
131-
// This method returns a Connection with a no-op Close method. This ensures that a
132-
// SingleConnectionDeployment can be used across multiple operation executions.
133145
func (ssd SingleConnectionDeployment) Connection(context.Context) (Connection, error) {
134-
return nopCloserConnection{ssd.C}, nil
146+
return ssd.C, nil
135147
}
136148

137-
// nopCloserConnection is an adapter used in a SingleConnectionDeployment. It passes through all
138-
// functionality expcect for closing, which is a no-op. This is done so the connection can be used
139-
// across multiple operations.
140-
type nopCloserConnection struct{ Connection }
141-
142-
func (ncc nopCloserConnection) Close() error { return nil }
143-
144149
// TODO(GODRIVER-617): We can likely use 1 type for both the Type and the RetryMode by using
145150
// 2 bits for the mode and 1 bit for the type. Although in the practical sense, we might not want to
146151
// do that since the type of retryability is tied to the operation itself and isn't going change,

x/mongo/driver/operation.go

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ func (op Operation) Execute(ctx context.Context, scratch []byte) error {
329329
if len(scratch) > 0 {
330330
scratch = scratch[:0]
331331
}
332-
wm, startedInfo, err := op.createWireMessage(ctx, scratch, desc)
332+
wm, startedInfo, err := op.createWireMessage(ctx, scratch, desc, conn)
333333
if err != nil {
334334
return err
335335
}
@@ -575,6 +575,12 @@ func (op Operation) roundTrip(ctx context.Context, conn Connection, wm []byte) (
575575
return nil, Error{Message: err.Error(), Labels: labels, Wrapped: err}
576576
}
577577

578+
return op.readWireMessage(ctx, conn, wm)
579+
}
580+
581+
func (op Operation) readWireMessage(ctx context.Context, conn Connection, wm []byte) ([]byte, error) {
582+
var err error
583+
578584
wm, err = conn.ReadWireMessage(ctx, wm[:0])
579585
if err != nil {
580586
labels := []string{NetworkError}
@@ -590,6 +596,12 @@ func (op Operation) roundTrip(ctx context.Context, conn Connection, wm []byte) (
590596
return nil, Error{Message: err.Error(), Labels: labels, Wrapped: err}
591597
}
592598

599+
// If we're using a streamable connection, we set its streaming state based on the moreToCome flag in the server
600+
// response.
601+
if streamer, ok := conn.(StreamerConnection); ok {
602+
streamer.SetStreaming(wiremessage.IsMsgMoreToCome(wm))
603+
}
604+
593605
// decompress wiremessage
594606
wm, err = op.decompressWireMessage(wm)
595607
if err != nil {
@@ -675,12 +687,12 @@ func (Operation) decompressWireMessage(wm []byte) ([]byte, error) {
675687
}
676688

677689
func (op Operation) createWireMessage(ctx context.Context, dst []byte,
678-
desc description.SelectedServer) ([]byte, startedInformation, error) {
690+
desc description.SelectedServer, conn Connection) ([]byte, startedInformation, error) {
679691

680692
if desc.WireVersion == nil || desc.WireVersion.Max < wiremessage.OpmsgWireVersion {
681693
return op.createQueryWireMessage(dst, desc)
682694
}
683-
return op.createMsgWireMessage(ctx, dst, desc)
695+
return op.createMsgWireMessage(ctx, dst, desc, conn)
684696
}
685697

686698
func (op Operation) addBatchArray(dst []byte) []byte {
@@ -758,7 +770,9 @@ func (op Operation) createQueryWireMessage(dst []byte, desc description.Selected
758770
return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), info, nil
759771
}
760772

761-
func (op Operation) createMsgWireMessage(ctx context.Context, dst []byte, desc description.SelectedServer) ([]byte, startedInformation, error) {
773+
func (op Operation) createMsgWireMessage(ctx context.Context, dst []byte, desc description.SelectedServer,
774+
conn Connection) ([]byte, startedInformation, error) {
775+
762776
var info startedInformation
763777
var flags wiremessage.MsgFlag
764778
var wmindex int32
@@ -767,6 +781,12 @@ func (op Operation) createMsgWireMessage(ctx context.Context, dst []byte, desc d
767781
if op.WriteConcern != nil && !writeconcern.AckWrite(op.WriteConcern) && (op.Batches == nil || len(op.Batches.Documents) == 0) {
768782
flags = wiremessage.MoreToCome
769783
}
784+
// Set the ExhaustAllowed flag if the connection supports streaming. This will tell the server that it can
785+
// respond with the MoreToCome flag and then stream responses over this connection.
786+
if streamer, ok := conn.(StreamerConnection); ok && streamer.SupportsStreaming() {
787+
flags |= wiremessage.ExhaustAllowed
788+
}
789+
770790
info.requestID = wiremessage.NextRequestID()
771791
wmindex, dst = wiremessage.AppendHeaderStart(dst, info.requestID, 0, wiremessage.OpMsg)
772792
dst = wiremessage.AppendMsgFlags(dst, flags)

x/mongo/driver/operation_exhaust.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// Copyright (C) MongoDB, Inc. 2017-present.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
4+
// not use this file except in compliance with the License. You may obtain
5+
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6+
7+
package driver
8+
9+
import (
10+
"context"
11+
"errors"
12+
13+
"go.mongodb.org/mongo-driver/x/mongo/driver/description"
14+
)
15+
16+
// ExecuteExhaust reads a response from the provided StreamerConnection. This will error if the connection's
17+
// CurrentlyStreaming function returns false.
18+
func (op Operation) ExecuteExhaust(ctx context.Context, conn StreamerConnection, scratch []byte) error {
19+
if !conn.CurrentlyStreaming() {
20+
return errors.New("exhaust read must be done with a connection that is currently streaming")
21+
}
22+
23+
scratch = scratch[:0]
24+
res, err := op.readWireMessage(ctx, conn, scratch)
25+
if err != nil {
26+
return err
27+
}
28+
if op.ProcessResponseFn != nil {
29+
if err = op.ProcessResponseFn(res, nil, description.Server{}); err != nil {
30+
return err
31+
}
32+
}
33+
34+
return nil
35+
}

x/mongo/driver/operation_test.go

Lines changed: 100 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"github.com/google/go-cmp/cmp"
1111
"go.mongodb.org/mongo-driver/bson/bsontype"
1212
"go.mongodb.org/mongo-driver/bson/primitive"
13+
"go.mongodb.org/mongo-driver/internal/testutil/assert"
1314
"go.mongodb.org/mongo-driver/mongo/readconcern"
1415
"go.mongodb.org/mongo-driver/mongo/readpref"
1516
"go.mongodb.org/mongo-driver/mongo/writeconcern"
@@ -518,6 +519,93 @@ func TestOperation(t *testing.T) {
518519
})
519520
}
520521
})
522+
t.Run("ExecuteExhaust", func(t *testing.T) {
523+
t.Run("errors if connection is not streaming", func(t *testing.T) {
524+
conn := &mockConnection{
525+
rStreaming: false,
526+
}
527+
err := Operation{}.ExecuteExhaust(context.TODO(), conn, nil)
528+
assert.NotNil(t, err, "expected error, got nil")
529+
})
530+
})
531+
t.Run("exhaustAllowed and moreToCome", func(t *testing.T) {
532+
// Test the interaction between exhaustAllowed and moreToCome on requests/responses when using the Execute
533+
// and ExecuteExhaust methods.
534+
535+
// Create a server response wire message that has moreToCome=false.
536+
serverResponseDoc := bsoncore.BuildDocumentFromElements(nil,
537+
bsoncore.AppendInt32Element(nil, "ok", 1),
538+
)
539+
nonStreamingResponse := createExhaustServerResponse(t, serverResponseDoc, false)
540+
541+
// Create a connection that reports that it cannot stream messages.
542+
conn := &mockConnection{
543+
rDesc: description.Server{
544+
WireVersion: &description.VersionRange{
545+
Max: 6,
546+
},
547+
},
548+
rReadWM: nonStreamingResponse,
549+
rCanStream: false,
550+
}
551+
op := Operation{
552+
CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) {
553+
return bsoncore.AppendInt32Element(dst, "isMaster", 1), nil
554+
},
555+
Database: "admin",
556+
Deployment: SingleConnectionDeployment{conn},
557+
}
558+
err := op.Execute(context.TODO(), nil)
559+
assert.Nil(t, err, "Execute error: %v", err)
560+
561+
// The wire message sent to the server should not have exhaustAllowed=true. After execution, the connection
562+
// should not be in a streaming state.
563+
assertExhaustAllowedSet(t, conn.pWriteWM, false)
564+
assert.False(t, conn.CurrentlyStreaming(), "expected CurrentlyStreaming to be false")
565+
566+
// Modify the connection to report that it can stream and create a new server response with moreToCome=true.
567+
streamingResponse := createExhaustServerResponse(t, serverResponseDoc, true)
568+
conn.rReadWM = streamingResponse
569+
conn.rCanStream = true
570+
err = op.Execute(context.TODO(), nil)
571+
assert.Nil(t, err, "Execute error: %v", err)
572+
assertExhaustAllowedSet(t, conn.pWriteWM, true)
573+
assert.True(t, conn.CurrentlyStreaming(), "expected CurrentlyStreaming to be true")
574+
575+
// Reset the server response and go through ExecuteExhaust to mimic streaming the next response. After
576+
// execution, the connection should still be in a streaming state.
577+
conn.rReadWM = streamingResponse
578+
err = op.ExecuteExhaust(context.TODO(), conn, nil)
579+
assert.Nil(t, err, "ExecuteExhaust error: %v", err)
580+
assert.True(t, conn.CurrentlyStreaming(), "expected CurrentlyStreaming to be true")
581+
})
582+
}
583+
584+
func createExhaustServerResponse(t *testing.T, response bsoncore.Document, moreToCome bool) []byte {
585+
idx, wm := wiremessage.AppendHeaderStart(nil, 0, wiremessage.CurrentRequestID()+1, wiremessage.OpMsg)
586+
var flags wiremessage.MsgFlag
587+
if moreToCome {
588+
flags = wiremessage.MoreToCome
589+
}
590+
wm = wiremessage.AppendMsgFlags(wm, flags)
591+
wm = wiremessage.AppendMsgSectionType(wm, wiremessage.SingleDocument)
592+
wm = bsoncore.AppendDocument(wm, response)
593+
return bsoncore.UpdateLength(wm, idx, int32(len(wm)))
594+
}
595+
596+
func assertExhaustAllowedSet(t *testing.T, wm []byte, expected bool) {
597+
t.Helper()
598+
_, _, _, _, wm, ok := wiremessage.ReadHeader(wm)
599+
if !ok {
600+
t.Fatal("could not read wm header")
601+
}
602+
flags, wm, ok := wiremessage.ReadMsgFlags(wm)
603+
if !ok {
604+
t.Fatal("could not read wm flags")
605+
}
606+
607+
actual := flags&wiremessage.ExhaustAllowed > 0
608+
assert.Equal(t, expected, actual, "expected exhaustAllowed set %v, got %v", expected, actual)
521609
}
522610

523611
type mockDeployment struct {
@@ -554,19 +642,24 @@ type mockConnection struct {
554642
pReadDst []byte
555643

556644
// returns
557-
rWriteErr error
558-
rReadWM []byte
559-
rReadErr error
560-
rDesc description.Server
561-
rCloseErr error
562-
rID string
563-
rAddr address.Address
645+
rWriteErr error
646+
rReadWM []byte
647+
rReadErr error
648+
rDesc description.Server
649+
rCloseErr error
650+
rID string
651+
rAddr address.Address
652+
rCanStream bool
653+
rStreaming bool
564654
}
565655

566656
func (m *mockConnection) Description() description.Server { return m.rDesc }
567657
func (m *mockConnection) Close() error { return m.rCloseErr }
568658
func (m *mockConnection) ID() string { return m.rID }
569659
func (m *mockConnection) Address() address.Address { return m.rAddr }
660+
func (m *mockConnection) SupportsStreaming() bool { return m.rCanStream }
661+
func (m *mockConnection) CurrentlyStreaming() bool { return m.rStreaming }
662+
func (m *mockConnection) SetStreaming(streaming bool) { m.rStreaming = streaming }
570663

571664
func (m *mockConnection) WriteWireMessage(_ context.Context, wm []byte) error {
572665
m.pWriteWM = wm

x/mongo/driver/topology/connection.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ type connection struct {
4949
config *connectionConfig
5050
cancelConnectContext context.CancelFunc
5151
connectContextMade chan struct{}
52+
canStream bool
53+
currentlyStreaming bool
5254

5355
// pool related fields
5456
pool *pool
@@ -339,6 +341,7 @@ func (c *connection) bumpIdleDeadline() {
339341
type initConnection struct{ *connection }
340342

341343
var _ driver.Connection = initConnection{}
344+
var _ driver.StreamerConnection = initConnection{}
342345

343346
func (c initConnection) Description() description.Server {
344347
if c.connection == nil {
@@ -361,6 +364,15 @@ func (c initConnection) WriteWireMessage(ctx context.Context, wm []byte) error {
361364
func (c initConnection) ReadWireMessage(ctx context.Context, dst []byte) ([]byte, error) {
362365
return c.readWireMessage(ctx, dst)
363366
}
367+
func (c initConnection) SetStreaming(streaming bool) {
368+
c.currentlyStreaming = streaming
369+
}
370+
func (c initConnection) CurrentlyStreaming() bool {
371+
return c.currentlyStreaming
372+
}
373+
func (c initConnection) SupportsStreaming() bool {
374+
return c.canStream
375+
}
364376

365377
// Connection implements the driver.Connection interface to allow reading and writing wire
366378
// messages and the driver.Expirable interface to allow expiring.

0 commit comments

Comments
 (0)