@@ -24,13 +24,16 @@ import (
24
24
"sync/atomic"
25
25
"time"
26
26
27
+ "github.com/mongodb/mongo-go-driver/bson"
27
28
"github.com/mongodb/mongo-go-driver/core/address"
28
29
"github.com/mongodb/mongo-go-driver/core/compressor"
29
30
"github.com/mongodb/mongo-go-driver/core/description"
31
+ "github.com/mongodb/mongo-go-driver/core/event"
30
32
"github.com/mongodb/mongo-go-driver/core/wiremessage"
31
33
)
32
34
33
35
var globalClientConnectionID uint64
36
+ var emptyDoc = bson .NewDocument ()
34
37
35
38
func nextClientConnectionID () uint64 {
36
39
return atomic .AddUint64 (& globalClientConnectionID , 1 )
@@ -89,10 +92,12 @@ type connection struct {
89
92
compressor compressor.Compressor // use for compressing messages
90
93
// server can compress response with any compressor supported by driver
91
94
compressorMap map [wiremessage.CompressorID ]compressor.Compressor
95
+ commandMap map [int64 ]* event.CommandMetadata // map for monitoring commands sent to server
92
96
dead bool
93
97
idleTimeout time.Duration
94
98
idleDeadline time.Time
95
99
lifetimeDeadline time.Time
100
+ cmdMonitor * event.CommandMonitor
96
101
readTimeout time.Duration
97
102
uncompressBuf []byte // buffer to uncompress messages
98
103
writeTimeout time.Duration
@@ -140,6 +145,7 @@ func New(ctx context.Context, addr address.Address, opts ...Option) (Connection,
140
145
conn : nc ,
141
146
compressBuf : make ([]byte , 256 ),
142
147
compressorMap : compressorMap ,
148
+ commandMap : make (map [int64 ]* event.CommandMetadata ),
143
149
addr : addr ,
144
150
idleTimeout : cfg .idleTimeout ,
145
151
lifetimeDeadline : lifetimeDeadline ,
@@ -180,6 +186,7 @@ func New(ctx context.Context, addr address.Address, opts ...Option) (Connection,
180
186
desc = & d
181
187
}
182
188
189
+ c .cmdMonitor = cfg .cmdMonitor // attach the command monitor later to avoid monitoring auth
183
190
return c , desc , nil
184
191
}
185
192
@@ -342,6 +349,223 @@ func (c *connection) uncompressMessage(compressed wiremessage.Compressed) ([]byt
342
349
return fullMessage , origHeader .OpCode , nil
343
350
}
344
351
352
+ func canMonitor (cmd string ) bool {
353
+ if cmd == "authenticate" || cmd == "saslStart" || cmd == "saslContinue" || cmd == "getnonce" || cmd == "createUser" ||
354
+ cmd == "updateUser" || cmd == "copydbgetnonce" || cmd == "copydbsaslstart" || cmd == "copydb" {
355
+ return false
356
+ }
357
+
358
+ return true
359
+ }
360
+
361
+ func (c * connection ) commandStartedEvent (wm wiremessage.WireMessage ) error {
362
+ if c .cmdMonitor == nil || c .cmdMonitor .Started == nil {
363
+ return nil
364
+ }
365
+
366
+ startedEvent := & event.CommandStartedEvent {
367
+ ConnectionID : c .id ,
368
+ }
369
+
370
+ var cmd * bson.Document
371
+ var err error
372
+
373
+ var acknowledged bool
374
+ switch converted := wm .(type ) {
375
+ case wiremessage.Query :
376
+ cmd , err = bson .ReadDocument ([]byte (converted .Query ))
377
+ if err != nil {
378
+ return err
379
+ }
380
+
381
+ acknowledged = converted .AcknowledgedWrite ()
382
+ startedEvent .DatabaseName = converted .FullCollectionName [:len (converted .FullCollectionName )- 5 ] // remove $.cmd
383
+ startedEvent .RequestID = int64 (converted .MsgHeader .RequestID )
384
+
385
+ cmdElem := cmd .ElementAt (0 )
386
+ if cmdElem .Key () == "$query" {
387
+ cmd = cmdElem .Value ().MutableDocument ()
388
+ }
389
+ case wiremessage.Msg :
390
+ cmd , err = converted .GetMainDocument ()
391
+ if err != nil {
392
+ return err
393
+ }
394
+
395
+ acknowledged = converted .AcknowledgedWrite ()
396
+ arr , identifier , err := converted .GetSequenceArray ()
397
+ if err != nil {
398
+ return err
399
+ }
400
+ if arr != nil {
401
+ cmd = cmd .Copy () // make copy to avoid changing original command
402
+ cmd .Append (bson .EC .Array (identifier , arr ))
403
+ }
404
+
405
+ dbVal , err := cmd .LookupErr ("$db" )
406
+ if err != nil {
407
+ return err
408
+ }
409
+
410
+ startedEvent .DatabaseName = dbVal .StringValue ()
411
+ startedEvent .RequestID = int64 (converted .MsgHeader .RequestID )
412
+ }
413
+
414
+ startedEvent .Command = cmd
415
+ startedEvent .CommandName = cmd .ElementAt (0 ).Key ()
416
+ if ! canMonitor (startedEvent .CommandName ) {
417
+ startedEvent .Command = emptyDoc
418
+ }
419
+
420
+ c .cmdMonitor .Started (startedEvent )
421
+
422
+ if ! acknowledged {
423
+ if c .cmdMonitor .Succeeded == nil {
424
+ return nil
425
+ }
426
+
427
+ // unack writes must provide a CommandSucceededEvent with an { ok: 1 } reply
428
+ finishedEvent := event.CommandFinishedEvent {
429
+ DurationNanos : 0 ,
430
+ CommandName : startedEvent .CommandName ,
431
+ RequestID : startedEvent .RequestID ,
432
+ ConnectionID : c .id ,
433
+ }
434
+
435
+ c .cmdMonitor .Succeeded (& event.CommandSucceededEvent {
436
+ CommandFinishedEvent : finishedEvent ,
437
+ Reply : bson .NewDocument (
438
+ bson .EC .Int32 ("ok" , 1 ),
439
+ ),
440
+ })
441
+
442
+ return nil
443
+ }
444
+
445
+ c .commandMap [startedEvent .RequestID ] = event .CreateMetadata (startedEvent .CommandName )
446
+ return nil
447
+ }
448
+
449
+ func processReply (reply * bson.Document ) (bool , string ) {
450
+ iter := reply .Iterator ()
451
+ var success bool
452
+ var errmsg string
453
+ var errCode int32
454
+
455
+ for iter .Next () {
456
+ elem := iter .Element ()
457
+ switch elem .Key () {
458
+ case "ok" :
459
+ switch elem .Value ().Type () {
460
+ case bson .TypeInt32 :
461
+ if elem .Value ().Int32 () == 1 {
462
+ success = true
463
+ }
464
+ case bson .TypeInt64 :
465
+ if elem .Value ().Int64 () == 1 {
466
+ success = true
467
+ }
468
+ case bson .TypeDouble :
469
+ if elem .Value ().Double () == 1 {
470
+ success = true
471
+ }
472
+ }
473
+ case "errmsg" :
474
+ if str , ok := elem .Value ().StringValueOK (); ok {
475
+ errmsg = str
476
+ }
477
+ case "code" :
478
+ if c , ok := elem .Value ().Int32OK (); ok {
479
+ errCode = c
480
+ }
481
+ }
482
+ }
483
+
484
+ if success {
485
+ return true , ""
486
+ }
487
+
488
+ fullErrMsg := fmt .Sprintf ("Error code %d: %s" , errCode , errmsg )
489
+ return false , fullErrMsg
490
+ }
491
+
492
+ func (c * connection ) commandFinishedEvent (wm wiremessage.WireMessage ) error {
493
+ if c .cmdMonitor == nil {
494
+ return nil
495
+ }
496
+
497
+ var reply * bson.Document
498
+ var requestID int64
499
+ var err error
500
+
501
+ switch converted := wm .(type ) {
502
+ case wiremessage.Reply :
503
+ requestID = int64 (converted .MsgHeader .ResponseTo )
504
+ reply , err = converted .GetMainDocument ()
505
+ case wiremessage.Msg :
506
+ requestID = int64 (converted .MsgHeader .ResponseTo )
507
+ reply , err = converted .GetMainDocument ()
508
+ }
509
+
510
+ if err != nil {
511
+ return err
512
+ }
513
+
514
+ cmdMetadata := c .commandMap [requestID ]
515
+ delete (c .commandMap , requestID )
516
+ success , errmsg := processReply (reply )
517
+
518
+ if (success && c .cmdMonitor .Succeeded == nil ) || (! success && c .cmdMonitor .Failed == nil ) {
519
+ return nil
520
+ }
521
+
522
+ finishedEvent := event.CommandFinishedEvent {
523
+ DurationNanos : cmdMetadata .TimeDifference (),
524
+ CommandName : cmdMetadata .Name ,
525
+ RequestID : requestID ,
526
+ ConnectionID : c .id ,
527
+ }
528
+
529
+ if success {
530
+ if ! canMonitor (finishedEvent .CommandName ) {
531
+ successEvent := & event.CommandSucceededEvent {
532
+ Reply : emptyDoc ,
533
+ CommandFinishedEvent : finishedEvent ,
534
+ }
535
+ c .cmdMonitor .Succeeded (successEvent )
536
+ return nil
537
+ }
538
+
539
+ // if response has type 1 document sequence, the sequence must be included as a BSON array in the event's reply.
540
+ if opmsg , ok := wm .(wiremessage.Msg ); ok {
541
+ arr , identifier , err := opmsg .GetSequenceArray ()
542
+ if err != nil {
543
+ return err
544
+ }
545
+ if arr != nil {
546
+ reply = reply .Copy () // make copy to avoid changing original command
547
+ reply .Append (bson .EC .Array (identifier , arr ))
548
+ }
549
+ }
550
+
551
+ successEvent := & event.CommandSucceededEvent {
552
+ Reply : reply ,
553
+ CommandFinishedEvent : finishedEvent ,
554
+ }
555
+
556
+ c .cmdMonitor .Succeeded (successEvent )
557
+ return nil
558
+ }
559
+
560
+ failureEvent := & event.CommandFailedEvent {
561
+ Failure : errmsg ,
562
+ CommandFinishedEvent : finishedEvent ,
563
+ }
564
+
565
+ c .cmdMonitor .Failed (failureEvent )
566
+ return nil
567
+ }
568
+
345
569
func (c * connection ) WriteWireMessage (ctx context.Context , wm wiremessage.WireMessage ) error {
346
570
var err error
347
571
if c .dead {
@@ -415,6 +639,10 @@ func (c *connection) WriteWireMessage(ctx context.Context, wm wiremessage.WireMe
415
639
}
416
640
417
641
c .bumpIdleDeadline ()
642
+ err = c .commandStartedEvent (wm )
643
+ if err != nil {
644
+ return err
645
+ }
418
646
return nil
419
647
}
420
648
@@ -562,6 +790,11 @@ func (c *connection) ReadWireMessage(ctx context.Context) (wiremessage.WireMessa
562
790
}
563
791
564
792
c .bumpIdleDeadline ()
793
+ err = c .commandFinishedEvent (wm )
794
+ if err != nil {
795
+ return nil , err // TODO: do we care if monitoring fails?
796
+ }
797
+
565
798
return wm , nil
566
799
}
567
800
0 commit comments