Skip to content

Commit bbe8c4a

Browse files
author
iwysiu
authored
GODRIVER-1613 fix how pool handles MaxPoolSize (#413)
1 parent ceaedd5 commit bbe8c4a

File tree

7 files changed

+251
-104
lines changed

7 files changed

+251
-104
lines changed

internal/testutil/assert/assert.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
package assert
88

99
import (
10+
"errors"
1011
"reflect"
1112
"sync"
1213
"testing"
14+
"time"
1315

1416
"github.com/google/go-cmp/cmp"
1517
)
@@ -23,6 +25,7 @@ var errorCompareFn = func(e1, e2 error) bool {
2325
return e1.Error() == e2.Error()
2426
}
2527
var errorCompareOpts = cmp.Options{cmp.Comparer(errorCompareFn)}
28+
var ErrCallbackTimedOut = errors.New("callback timed out")
2629

2730
// RegisterOpts registers go-cmp options for a type. These options will be used when comparing two objects for equality.
2831
func RegisterOpts(t reflect.Type, opts ...cmp.Option) {
@@ -80,6 +83,29 @@ func NotNil(t testing.TB, obj interface{}, msg string, args ...interface{}) {
8083
}
8184
}
8285

86+
// RunWithTimeout runs the provided callback for a maximum of timeoutMS milliseconds. It returns the callback error
87+
// if the callback returned and ErrCallbackTimedOut if the timeout expired.
88+
func RunWithTimeout(callback func() error, timeout time.Duration) error {
89+
done := make(chan struct{})
90+
var err error
91+
fullCallback := func() {
92+
err = callback()
93+
done <- struct{}{}
94+
}
95+
96+
timer := time.NewTimer(timeout)
97+
defer timer.Stop()
98+
99+
go fullCallback()
100+
101+
select {
102+
case <-done:
103+
return err
104+
case <-timer.C:
105+
return ErrCallbackTimedOut
106+
}
107+
}
108+
83109
func getCmpOpts(obj interface{}) cmp.Options {
84110
opts, ok := cmpOpts.Load(reflect.TypeOf(obj))
85111
if ok {

x/mongo/driver/topology/connection.go

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -297,19 +297,15 @@ func (c *connection) readWireMessage(ctx context.Context, dst []byte) ([]byte, e
297297
}
298298

299299
func (c *connection) close() error {
300-
if atomic.LoadInt32(&c.connected) != connected {
300+
if !atomic.CompareAndSwapInt32(&c.connected, connected, disconnected) {
301301
return nil
302302
}
303303

304304
var err error
305305
if c.nc != nil {
306306
err = c.nc.Close()
307307
}
308-
atomic.StoreInt32(&c.connected, disconnected)
309308

310-
if c.pool != nil {
311-
_ = c.pool.removeConnection(c)
312-
}
313309
return err
314310
}
315311

@@ -458,9 +454,7 @@ func (c *Connection) Close() error {
458454
if c.connection == nil {
459455
return nil
460456
}
461-
if c.s != nil {
462-
defer c.s.sem.Release(1)
463-
}
457+
464458
err := c.pool.put(c.connection)
465459
c.connection = nil
466460
return err
@@ -473,10 +467,9 @@ func (c *Connection) Expire() error {
473467
if c.connection == nil {
474468
return nil
475469
}
476-
if c.s != nil {
477-
c.s.sem.Release(1)
478-
}
479-
err := c.close()
470+
471+
_ = c.close()
472+
err := c.pool.put(c.connection)
480473
c.connection = nil
481474
return err
482475
}

x/mongo/driver/topology/pool.go

Lines changed: 105 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@ package topology
88

99
import (
1010
"context"
11+
"math"
1112
"sync"
1213
"sync/atomic"
1314
"time"
1415

1516
"go.mongodb.org/mongo-driver/event"
1617
"go.mongodb.org/mongo-driver/x/mongo/driver/address"
18+
"golang.org/x/sync/semaphore"
1719
)
1820

1921
// ErrPoolConnected is returned from an attempt to connect an already connected pool
@@ -67,6 +69,7 @@ type pool struct {
6769
connected int32 // Must be accessed using the sync/atomic package.
6870
nextid uint64
6971
opened map[uint64]*connection // opened holds all of the currently open connections.
72+
sem *semaphore.Weighted
7073
sync.Mutex
7174
}
7275

@@ -116,6 +119,7 @@ func connectionCloseFunc(v interface{}) {
116119
return
117120
}
118121

122+
_ = c.pool.removeConnection(c)
119123
go func() {
120124
_ = c.pool.closeConnection(c)
121125
}()
@@ -141,16 +145,23 @@ func newPool(config poolConfig, connOpts ...ConnectionOption) (*pool, error) {
141145
opts = append(opts, WithIdleTimeout(func(_ time.Duration) time.Duration { return config.MaxIdleTime }))
142146
}
143147

148+
var maxConns = config.MaxPoolSize
149+
if maxConns == 0 {
150+
maxConns = math.MaxInt64
151+
}
152+
144153
pool := &pool{
145154
address: config.Address,
146155
monitor: config.PoolMonitor,
147156
connected: disconnected,
148157
opened: make(map[uint64]*connection),
149158
opts: opts,
159+
sem: semaphore.NewWeighted(int64(maxConns)),
150160
}
151161

152162
// we do not pass in config.MaxPoolSize because we manage the max size at this level rather than the resource pool level
153163
rpc := resourcePoolConfig{
164+
MaxSize: maxConns,
154165
MinSize: config.MinPoolSize,
155166
MaintainInterval: maintainInterval,
156167
ExpiredFn: connectionExpiredFunc,
@@ -162,7 +173,7 @@ func newPool(config poolConfig, connOpts ...ConnectionOption) (*pool, error) {
162173
pool.monitor.Event(&event.PoolEvent{
163174
Type: event.PoolCreated,
164175
PoolOptions: &event.MonitorPoolOptions{
165-
MaxPoolSize: config.MaxPoolSize,
176+
MaxPoolSize: rpc.MaxSize,
166177
MinPoolSize: rpc.MinSize,
167178
WaitQueueTimeoutMS: uint64(config.MaxIdleTime) / uint64(time.Millisecond),
168179
},
@@ -249,9 +260,11 @@ func (p *pool) disconnect(ctx context.Context) error {
249260
Reason: event.ReasonPoolClosed,
250261
})
251262
}
263+
_ = p.removeConnection(pc)
252264
_ = p.closeConnection(pc) // We don't care about errors while closing the connection.
253265
}
254266
atomic.StoreInt32(&p.connected, disconnected)
267+
p.conns.clearTotal()
255268

256269
if p.monitor != nil {
257270
p.monitor.Event(&event.PoolEvent{
@@ -305,7 +318,6 @@ func (p *pool) makeNewConnection(ctx context.Context) (*connection, string, erro
305318

306319
// Checkout returns a connection from the pool
307320
func (p *pool) get(ctx context.Context) (*connection, error) {
308-
309321
if ctx == nil {
310322
ctx = context.Background()
311323
}
@@ -321,81 +333,120 @@ func (p *pool) get(ctx context.Context) (*connection, error) {
321333
return nil, ErrPoolDisconnected
322334
}
323335

324-
connVal := p.conns.Get()
325-
if c, ok := connVal.(*connection); ok && connVal != nil {
326-
// call connect if not connected
327-
if atomic.LoadInt32(&c.connected) == initialized {
328-
c.connect(ctx)
336+
err := p.sem.Acquire(ctx, 1)
337+
if err != nil {
338+
if p.monitor != nil {
339+
p.monitor.Event(&event.PoolEvent{
340+
Type: event.GetFailed,
341+
Address: p.address.String(),
342+
Reason: event.ReasonTimedOut,
343+
})
329344
}
345+
return nil, ErrWaitQueueTimeout
346+
}
330347

331-
err := c.wait()
332-
if err != nil {
348+
// This loop is so that we don't end up with more than maxPoolSize connections if p.conns.Maintain runs between
349+
// calling p.conns.Get() and making the new connection
350+
for {
351+
if atomic.LoadInt32(&p.connected) != connected {
333352
if p.monitor != nil {
334353
p.monitor.Event(&event.PoolEvent{
335354
Type: event.GetFailed,
336355
Address: p.address.String(),
337-
Reason: event.ReasonConnectionErrored,
356+
Reason: event.ReasonPoolClosed,
338357
})
339358
}
340-
return nil, err
359+
p.sem.Release(1)
360+
return nil, ErrPoolDisconnected
341361
}
342362

343-
if p.monitor != nil {
344-
p.monitor.Event(&event.PoolEvent{
345-
Type: event.GetSucceeded,
346-
Address: p.address.String(),
347-
ConnectionID: c.poolID,
348-
})
349-
}
350-
return c, nil
351-
}
363+
connVal := p.conns.Get()
364+
if c, ok := connVal.(*connection); ok && connVal != nil {
365+
// call connect if not connected
366+
if atomic.LoadInt32(&c.connected) == initialized {
367+
c.connect(ctx)
368+
}
352369

353-
select {
354-
case <-ctx.Done():
355-
if p.monitor != nil {
356-
p.monitor.Event(&event.PoolEvent{
357-
Type: event.GetFailed,
358-
Address: p.address.String(),
359-
Reason: event.ReasonTimedOut,
360-
})
361-
}
362-
return nil, ctx.Err()
363-
default:
364-
c, reason, err := p.makeNewConnection(ctx)
370+
err := c.wait()
371+
if err != nil {
372+
if p.monitor != nil {
373+
p.monitor.Event(&event.PoolEvent{
374+
Type: event.GetFailed,
375+
Address: p.address.String(),
376+
Reason: event.ReasonConnectionErrored,
377+
})
378+
}
379+
p.conns.decrementTotal()
380+
p.sem.Release(1)
381+
return nil, err
382+
}
365383

366-
if err != nil {
367384
if p.monitor != nil {
368385
p.monitor.Event(&event.PoolEvent{
369-
Type: event.GetFailed,
370-
Address: p.address.String(),
371-
Reason: reason,
386+
Type: event.GetSucceeded,
387+
Address: p.address.String(),
388+
ConnectionID: c.poolID,
372389
})
373390
}
374-
return nil, err
391+
return c, nil
375392
}
376393

377-
c.connect(ctx)
378-
// wait for conn to be connected
379-
err = c.wait()
380-
if err != nil {
394+
select {
395+
case <-ctx.Done():
381396
if p.monitor != nil {
382397
p.monitor.Event(&event.PoolEvent{
383398
Type: event.GetFailed,
384399
Address: p.address.String(),
385-
Reason: reason,
400+
Reason: event.ReasonTimedOut,
386401
})
387402
}
388-
return nil, err
389-
}
403+
p.sem.Release(1)
404+
return nil, ctx.Err()
405+
default:
406+
made := p.conns.incrementTotal()
407+
if !made {
408+
continue
409+
}
410+
c, reason, err := p.makeNewConnection(ctx)
411+
412+
if err != nil {
413+
if p.monitor != nil {
414+
p.monitor.Event(&event.PoolEvent{
415+
Type: event.GetFailed,
416+
Address: p.address.String(),
417+
Reason: reason,
418+
})
419+
}
420+
p.sem.Release(1)
421+
p.conns.decrementTotal()
422+
return nil, err
423+
}
390424

391-
if p.monitor != nil {
392-
p.monitor.Event(&event.PoolEvent{
393-
Type: event.GetSucceeded,
394-
Address: p.address.String(),
395-
ConnectionID: c.poolID,
396-
})
425+
c.connect(ctx)
426+
// wait for conn to be connected
427+
err = c.wait()
428+
if err != nil {
429+
if p.monitor != nil {
430+
p.monitor.Event(&event.PoolEvent{
431+
Type: event.GetFailed,
432+
Address: p.address.String(),
433+
Reason: reason,
434+
})
435+
}
436+
p.sem.Release(1)
437+
p.conns.decrementTotal()
438+
return nil, err
439+
}
440+
441+
if p.monitor != nil {
442+
p.monitor.Event(&event.PoolEvent{
443+
Type: event.GetSucceeded,
444+
Address: p.address.String(),
445+
ConnectionID: c.poolID,
446+
})
447+
}
448+
return c, nil
397449
}
398-
return c, nil
399450
}
400451
}
401452

@@ -405,9 +456,6 @@ func (p *pool) closeConnection(c *connection) error {
405456
if c.pool != p {
406457
return ErrWrongPool
407458
}
408-
p.Lock()
409-
delete(p.opened, c.poolID)
410-
p.Unlock()
411459

412460
if atomic.LoadInt32(&c.connected) == connected {
413461
c.closeConnectContext()
@@ -441,8 +489,10 @@ func (p *pool) removeConnection(c *connection) error {
441489
}
442490

443491
// put returns a connection to this pool. If the pool is connected, the connection is not
444-
// stale, and there is space in the cache, the connection is returned to the cache.
492+
// stale, and there is space in the cache, the connection is returned to the cache. This
493+
// assumes that the connection has already been counted in p.conns.totalSize.
445494
func (p *pool) put(c *connection) error {
495+
defer p.sem.Release(1)
446496
if p.monitor != nil {
447497
var cid uint64
448498
var addr string

0 commit comments

Comments
 (0)