@@ -8,12 +8,14 @@ package topology
8
8
9
9
import (
10
10
"context"
11
+ "math"
11
12
"sync"
12
13
"sync/atomic"
13
14
"time"
14
15
15
16
"go.mongodb.org/mongo-driver/event"
16
17
"go.mongodb.org/mongo-driver/x/mongo/driver/address"
18
+ "golang.org/x/sync/semaphore"
17
19
)
18
20
19
21
// ErrPoolConnected is returned from an attempt to connect an already connected pool
@@ -67,6 +69,7 @@ type pool struct {
67
69
connected int32 // Must be accessed using the sync/atomic package.
68
70
nextid uint64
69
71
opened map [uint64 ]* connection // opened holds all of the currently open connections.
72
+ sem * semaphore.Weighted
70
73
sync.Mutex
71
74
}
72
75
@@ -116,6 +119,7 @@ func connectionCloseFunc(v interface{}) {
116
119
return
117
120
}
118
121
122
+ _ = c .pool .removeConnection (c )
119
123
go func () {
120
124
_ = c .pool .closeConnection (c )
121
125
}()
@@ -141,16 +145,23 @@ func newPool(config poolConfig, connOpts ...ConnectionOption) (*pool, error) {
141
145
opts = append (opts , WithIdleTimeout (func (_ time.Duration ) time.Duration { return config .MaxIdleTime }))
142
146
}
143
147
148
+ var maxConns = config .MaxPoolSize
149
+ if maxConns == 0 {
150
+ maxConns = math .MaxInt64
151
+ }
152
+
144
153
pool := & pool {
145
154
address : config .Address ,
146
155
monitor : config .PoolMonitor ,
147
156
connected : disconnected ,
148
157
opened : make (map [uint64 ]* connection ),
149
158
opts : opts ,
159
+ sem : semaphore .NewWeighted (int64 (maxConns )),
150
160
}
151
161
152
162
// we do not pass in config.MaxPoolSize because we manage the max size at this level rather than the resource pool level
153
163
rpc := resourcePoolConfig {
164
+ MaxSize : maxConns ,
154
165
MinSize : config .MinPoolSize ,
155
166
MaintainInterval : maintainInterval ,
156
167
ExpiredFn : connectionExpiredFunc ,
@@ -162,7 +173,7 @@ func newPool(config poolConfig, connOpts ...ConnectionOption) (*pool, error) {
162
173
pool .monitor .Event (& event.PoolEvent {
163
174
Type : event .PoolCreated ,
164
175
PoolOptions : & event.MonitorPoolOptions {
165
- MaxPoolSize : config . MaxPoolSize ,
176
+ MaxPoolSize : rpc . MaxSize ,
166
177
MinPoolSize : rpc .MinSize ,
167
178
WaitQueueTimeoutMS : uint64 (config .MaxIdleTime ) / uint64 (time .Millisecond ),
168
179
},
@@ -249,9 +260,11 @@ func (p *pool) disconnect(ctx context.Context) error {
249
260
Reason : event .ReasonPoolClosed ,
250
261
})
251
262
}
263
+ _ = p .removeConnection (pc )
252
264
_ = p .closeConnection (pc ) // We don't care about errors while closing the connection.
253
265
}
254
266
atomic .StoreInt32 (& p .connected , disconnected )
267
+ p .conns .clearTotal ()
255
268
256
269
if p .monitor != nil {
257
270
p .monitor .Event (& event.PoolEvent {
@@ -305,7 +318,6 @@ func (p *pool) makeNewConnection(ctx context.Context) (*connection, string, erro
305
318
306
319
// Checkout returns a connection from the pool
307
320
func (p * pool ) get (ctx context.Context ) (* connection , error ) {
308
-
309
321
if ctx == nil {
310
322
ctx = context .Background ()
311
323
}
@@ -321,81 +333,120 @@ func (p *pool) get(ctx context.Context) (*connection, error) {
321
333
return nil , ErrPoolDisconnected
322
334
}
323
335
324
- connVal := p .conns .Get ()
325
- if c , ok := connVal .(* connection ); ok && connVal != nil {
326
- // call connect if not connected
327
- if atomic .LoadInt32 (& c .connected ) == initialized {
328
- c .connect (ctx )
336
+ err := p .sem .Acquire (ctx , 1 )
337
+ if err != nil {
338
+ if p .monitor != nil {
339
+ p .monitor .Event (& event.PoolEvent {
340
+ Type : event .GetFailed ,
341
+ Address : p .address .String (),
342
+ Reason : event .ReasonTimedOut ,
343
+ })
329
344
}
345
+ return nil , ErrWaitQueueTimeout
346
+ }
330
347
331
- err := c .wait ()
332
- if err != nil {
348
+ // This loop is so that we don't end up with more than maxPoolSize connections if p.conns.Maintain runs between
349
+ // calling p.conns.Get() and making the new connection
350
+ for {
351
+ if atomic .LoadInt32 (& p .connected ) != connected {
333
352
if p .monitor != nil {
334
353
p .monitor .Event (& event.PoolEvent {
335
354
Type : event .GetFailed ,
336
355
Address : p .address .String (),
337
- Reason : event .ReasonConnectionErrored ,
356
+ Reason : event .ReasonPoolClosed ,
338
357
})
339
358
}
340
- return nil , err
359
+ p .sem .Release (1 )
360
+ return nil , ErrPoolDisconnected
341
361
}
342
362
343
- if p .monitor != nil {
344
- p .monitor .Event (& event.PoolEvent {
345
- Type : event .GetSucceeded ,
346
- Address : p .address .String (),
347
- ConnectionID : c .poolID ,
348
- })
349
- }
350
- return c , nil
351
- }
363
+ connVal := p .conns .Get ()
364
+ if c , ok := connVal .(* connection ); ok && connVal != nil {
365
+ // call connect if not connected
366
+ if atomic .LoadInt32 (& c .connected ) == initialized {
367
+ c .connect (ctx )
368
+ }
352
369
353
- select {
354
- case <- ctx .Done ():
355
- if p .monitor != nil {
356
- p .monitor .Event (& event.PoolEvent {
357
- Type : event .GetFailed ,
358
- Address : p .address .String (),
359
- Reason : event .ReasonTimedOut ,
360
- })
361
- }
362
- return nil , ctx .Err ()
363
- default :
364
- c , reason , err := p .makeNewConnection (ctx )
370
+ err := c .wait ()
371
+ if err != nil {
372
+ if p .monitor != nil {
373
+ p .monitor .Event (& event.PoolEvent {
374
+ Type : event .GetFailed ,
375
+ Address : p .address .String (),
376
+ Reason : event .ReasonConnectionErrored ,
377
+ })
378
+ }
379
+ p .conns .decrementTotal ()
380
+ p .sem .Release (1 )
381
+ return nil , err
382
+ }
365
383
366
- if err != nil {
367
384
if p .monitor != nil {
368
385
p .monitor .Event (& event.PoolEvent {
369
- Type : event .GetFailed ,
370
- Address : p .address .String (),
371
- Reason : reason ,
386
+ Type : event .GetSucceeded ,
387
+ Address : p .address .String (),
388
+ ConnectionID : c . poolID ,
372
389
})
373
390
}
374
- return nil , err
391
+ return c , nil
375
392
}
376
393
377
- c .connect (ctx )
378
- // wait for conn to be connected
379
- err = c .wait ()
380
- if err != nil {
394
+ select {
395
+ case <- ctx .Done ():
381
396
if p .monitor != nil {
382
397
p .monitor .Event (& event.PoolEvent {
383
398
Type : event .GetFailed ,
384
399
Address : p .address .String (),
385
- Reason : reason ,
400
+ Reason : event . ReasonTimedOut ,
386
401
})
387
402
}
388
- return nil , err
389
- }
403
+ p .sem .Release (1 )
404
+ return nil , ctx .Err ()
405
+ default :
406
+ made := p .conns .incrementTotal ()
407
+ if ! made {
408
+ continue
409
+ }
410
+ c , reason , err := p .makeNewConnection (ctx )
411
+
412
+ if err != nil {
413
+ if p .monitor != nil {
414
+ p .monitor .Event (& event.PoolEvent {
415
+ Type : event .GetFailed ,
416
+ Address : p .address .String (),
417
+ Reason : reason ,
418
+ })
419
+ }
420
+ p .sem .Release (1 )
421
+ p .conns .decrementTotal ()
422
+ return nil , err
423
+ }
390
424
391
- if p .monitor != nil {
392
- p .monitor .Event (& event.PoolEvent {
393
- Type : event .GetSucceeded ,
394
- Address : p .address .String (),
395
- ConnectionID : c .poolID ,
396
- })
425
+ c .connect (ctx )
426
+ // wait for conn to be connected
427
+ err = c .wait ()
428
+ if err != nil {
429
+ if p .monitor != nil {
430
+ p .monitor .Event (& event.PoolEvent {
431
+ Type : event .GetFailed ,
432
+ Address : p .address .String (),
433
+ Reason : reason ,
434
+ })
435
+ }
436
+ p .sem .Release (1 )
437
+ p .conns .decrementTotal ()
438
+ return nil , err
439
+ }
440
+
441
+ if p .monitor != nil {
442
+ p .monitor .Event (& event.PoolEvent {
443
+ Type : event .GetSucceeded ,
444
+ Address : p .address .String (),
445
+ ConnectionID : c .poolID ,
446
+ })
447
+ }
448
+ return c , nil
397
449
}
398
- return c , nil
399
450
}
400
451
}
401
452
@@ -405,9 +456,6 @@ func (p *pool) closeConnection(c *connection) error {
405
456
if c .pool != p {
406
457
return ErrWrongPool
407
458
}
408
- p .Lock ()
409
- delete (p .opened , c .poolID )
410
- p .Unlock ()
411
459
412
460
if atomic .LoadInt32 (& c .connected ) == connected {
413
461
c .closeConnectContext ()
@@ -441,8 +489,10 @@ func (p *pool) removeConnection(c *connection) error {
441
489
}
442
490
443
491
// put returns a connection to this pool. If the pool is connected, the connection is not
444
- // stale, and there is space in the cache, the connection is returned to the cache.
492
+ // stale, and there is space in the cache, the connection is returned to the cache. This
493
+ // assumes that the connection has already been counted in p.conns.totalSize.
445
494
func (p * pool ) put (c * connection ) error {
495
+ defer p .sem .Release (1 )
446
496
if p .monitor != nil {
447
497
var cid uint64
448
498
var addr string
0 commit comments