Skip to content

Commit 0bef69c

Browse files
authored
GODRIVER-2147 remove session from context in internal CSFLE operations (#762)
1 parent 22266fc commit 0bef69c

File tree

3 files changed

+247
-0
lines changed

3 files changed

+247
-0
lines changed

mongo/crypt_retrievers.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ type keyRetriever struct {
1818
}
1919

2020
func (kr *keyRetriever) cryptKeys(ctx context.Context, filter bsoncore.Document) ([]bsoncore.Document, error) {
21+
// Remove the explicit session from the context if one is set.
22+
// The explicit session may be from a different client.
23+
ctx = NewSessionContext(ctx, nil)
2124
cursor, err := kr.coll.Find(ctx, filter)
2225
if err != nil {
2326
return nil, EncryptionKeyVaultError{Wrapped: err}
@@ -43,6 +46,9 @@ type collInfoRetriever struct {
4346
}
4447

4548
func (cir *collInfoRetriever) cryptCollInfo(ctx context.Context, db string, filter bsoncore.Document) (bsoncore.Document, error) {
49+
// Remove the explicit session from the context if one is set.
50+
// The explicit session may be from a different client.
51+
ctx = NewSessionContext(ctx, nil)
4652
cursor, err := cir.client.Database(db).ListCollections(ctx, filter)
4753
if err != nil {
4854
return nil, err
Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
// Copyright (C) MongoDB, Inc. 2021-present.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
4+
// not use this file except in compliance with the License. You may obtain
5+
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6+
7+
// +build cse
8+
9+
package integration
10+
11+
import (
12+
"context"
13+
"testing"
14+
15+
"go.mongodb.org/mongo-driver/bson"
16+
"go.mongodb.org/mongo-driver/bson/bsontype"
17+
"go.mongodb.org/mongo-driver/bson/primitive"
18+
"go.mongodb.org/mongo-driver/event"
19+
"go.mongodb.org/mongo-driver/internal/testutil"
20+
"go.mongodb.org/mongo-driver/internal/testutil/assert"
21+
"go.mongodb.org/mongo-driver/mongo"
22+
"go.mongodb.org/mongo-driver/mongo/integration/mtest"
23+
"go.mongodb.org/mongo-driver/mongo/options"
24+
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
25+
)
26+
27+
// createDataKeyAndEncrypt creates a data key with the alternate name @keyName.
28+
// Returns a ciphertext encrypted with the data key as test data.
29+
func createDataKeyAndEncrypt(mt *mtest.T, keyName string) primitive.Binary {
30+
mt.Helper()
31+
32+
kvClientOpts := options.Client().
33+
ApplyURI(mtest.ClusterURI()).
34+
SetReadConcern(mtest.MajorityRc).
35+
SetWriteConcern(mtest.MajorityWc)
36+
37+
testutil.AddTestServerAPIVersion(kvClientOpts)
38+
39+
kmsProvidersMap := map[string]map[string]interface{}{
40+
"local": {"key": localMasterKey},
41+
}
42+
43+
kvClient, err := mongo.Connect(mtest.Background, kvClientOpts)
44+
defer kvClient.Disconnect(mtest.Background)
45+
assert.Nil(mt, err, "Connect error: %v", err)
46+
47+
err = kvClient.Database("keyvault").Collection("datakeys").Drop(mtest.Background)
48+
assert.Nil(mt, err, "Drop error: %v", err)
49+
50+
ceOpts := options.ClientEncryption().
51+
SetKmsProviders(kmsProvidersMap).
52+
SetKeyVaultNamespace("keyvault.datakeys")
53+
54+
ce, err := mongo.NewClientEncryption(kvClient, ceOpts)
55+
assert.Nil(mt, err, "NewClientEncryption error: %v", err)
56+
57+
dkOpts := options.DataKey().SetKeyAltNames([]string{keyName})
58+
_, err = ce.CreateDataKey(mtest.Background, "local", dkOpts)
59+
assert.Nil(mt, err, "CreateDataKey error: %v", err)
60+
61+
in := bson.RawValue{Type: bsontype.String, Value: bsoncore.AppendString(nil, "test")}
62+
eOpts := options.Encrypt().
63+
SetAlgorithm("AEAD_AES_256_CBC_HMAC_SHA_512-Random").
64+
SetKeyAltName(keyName)
65+
66+
ciphertext, err := ce.Encrypt(mtest.Background, in, eOpts)
67+
assert.Nil(mt, err, "Encrypt error: %v", err)
68+
return ciphertext
69+
}
70+
71+
func getLsid(mt *mtest.T, doc bson.Raw) bson.Raw {
72+
mt.Helper()
73+
74+
lsid, err := doc.LookupErr("lsid")
75+
assert.Nil(mt, err, "expected lsid in document: %v", doc)
76+
lsidDoc, ok := lsid.DocumentOK()
77+
assert.True(mt, ok, "expected lsid to be document, but got: %v", lsid)
78+
return lsidDoc
79+
}
80+
81+
func makeMonitor(mt *mtest.T, captured *[]event.CommandStartedEvent) *event.CommandMonitor {
82+
mt.Helper()
83+
assert.NotNil(mt, captured, "captured is nil")
84+
85+
return &event.CommandMonitor{
86+
Started: func(_ context.Context, cse *event.CommandStartedEvent) {
87+
assert.NotNil(mt, cse, "expected non-Nil CommandStartedEvent")
88+
*captured = append(*captured, *cse)
89+
},
90+
}
91+
}
92+
93+
func TestClientSideEncryptionWithExplicitSessions(t *testing.T) {
94+
verifyClientSideEncryptionVarsSet(t)
95+
mt := mtest.New(t, mtest.NewOptions().MinServerVersion("4.2").Enterprise(true).CreateClient(false))
96+
defer mt.Close()
97+
98+
kmsProvidersMap := map[string]map[string]interface{}{
99+
"local": {"key": localMasterKey},
100+
}
101+
102+
schema := bson.D{
103+
{"bsonType", "object"},
104+
{"properties", bson.D{
105+
{"encryptMe", bson.D{
106+
{"encrypt", bson.D{
107+
{"keyId", "/keyName"},
108+
{"bsonType", "string"},
109+
{"algorithm", "AEAD_AES_256_CBC_HMAC_SHA_512-Random"},
110+
}},
111+
}},
112+
}},
113+
}
114+
schemaMap := map[string]interface{}{"db.coll": schema}
115+
116+
mt.Run("automatic encryption", func(mt *mtest.T) {
117+
createDataKeyAndEncrypt(mt, "myKey")
118+
119+
aeOpts := options.AutoEncryption().
120+
SetKmsProviders(kmsProvidersMap).
121+
SetKeyVaultNamespace("keyvault.datakeys").
122+
SetSchemaMap(schemaMap)
123+
124+
var capturedEvents []event.CommandStartedEvent
125+
126+
clientOpts := options.Client().
127+
ApplyURI(mtest.ClusterURI()).
128+
SetReadConcern(mtest.MajorityRc).
129+
SetWriteConcern(mtest.MajorityWc).
130+
SetAutoEncryptionOptions(aeOpts).
131+
SetMonitor(makeMonitor(mt, &capturedEvents))
132+
133+
testutil.AddTestServerAPIVersion(clientOpts)
134+
135+
client, err := mongo.Connect(mtest.Background, clientOpts)
136+
assert.Nil(mt, err, "Connect error: %v", err)
137+
defer client.Disconnect(mtest.Background)
138+
139+
coll := client.Database("db").Collection("coll")
140+
err = coll.Drop(mtest.Background)
141+
assert.Nil(mt, err, "Drop error: %v", err)
142+
143+
session, err := client.StartSession()
144+
assert.Nil(mt, err, "StartSession error: %v", err)
145+
sessionCtx := mongo.NewSessionContext(mtest.Background, session)
146+
147+
capturedEvents = make([]event.CommandStartedEvent, 0)
148+
_, err = coll.InsertOne(sessionCtx, bson.D{{"encryptMe", "test"}, {"keyName", "myKey"}})
149+
assert.Nil(mt, err, "InsertOne error: %v", err)
150+
151+
assert.Equal(mt, len(capturedEvents), 2, "expected 2 events, got %v", len(capturedEvents))
152+
153+
// Assert the first event is a find on the keyvault.datakeys collection.
154+
event := capturedEvents[0]
155+
assert.Equal(mt, event.CommandName, "find", "expected command find, got %q", event.CommandName)
156+
assert.Equal(mt, event.DatabaseName, "keyvault", "expected find on keyvault, got %q", event.DatabaseName)
157+
158+
// Assert the find used an implicit session with an lsid != session.ID()
159+
lsid := getLsid(mt, event.Command)
160+
assert.Nil(mt, err, "lsid not found in %v", event.Command)
161+
assert.NotEqual(mt, lsid, session.ID(), "expected different lsid, but got %v", lsid)
162+
163+
// Assert the second event is the original insert.
164+
event = capturedEvents[1]
165+
assert.Equal(mt, event.CommandName, "insert", "expected command insert, got %q", event.CommandName)
166+
167+
// Assert the insert used the explicit session.
168+
lsid = getLsid(mt, event.Command)
169+
assert.Nil(mt, err, "lsid not found on %v", event.Command)
170+
assert.Equal(mt, lsid, session.ID(), "expected lsid %v, but got %v", session.ID(), lsid)
171+
172+
// Check that encryptMe is encrypted.
173+
encryptMe, err := event.Command.LookupErr("documents", "0", "encryptMe")
174+
assert.Nil(mt, err, "could not find encryptMe in %v", event.Command)
175+
assert.Equal(mt, encryptMe.Type, bson.TypeBinary, "expected Binary, got %v", encryptMe.Type)
176+
})
177+
178+
mt.Run("automatic decryption", func(mt *mtest.T) {
179+
ciphertext := createDataKeyAndEncrypt(mt, "myKey")
180+
181+
aeOpts := options.AutoEncryption().
182+
SetKmsProviders(kmsProvidersMap).
183+
SetKeyVaultNamespace("keyvault.datakeys").
184+
SetBypassAutoEncryption(true)
185+
186+
var capturedEvents []event.CommandStartedEvent
187+
188+
clientOpts := options.Client().
189+
ApplyURI(mtest.ClusterURI()).
190+
SetReadConcern(mtest.MajorityRc).
191+
SetWriteConcern(mtest.MajorityWc).
192+
SetAutoEncryptionOptions(aeOpts).
193+
SetMonitor(makeMonitor(mt, &capturedEvents))
194+
195+
testutil.AddTestServerAPIVersion(clientOpts)
196+
197+
client, err := mongo.Connect(mtest.Background, clientOpts)
198+
assert.Nil(mt, err, "Connect error: %v", err)
199+
defer client.Disconnect(mtest.Background)
200+
201+
coll := client.Database("db").Collection("coll")
202+
err = coll.Drop(mtest.Background)
203+
assert.Nil(mt, err, "Drop error: %v", err)
204+
_, err = coll.InsertOne(mtest.Background, bson.D{{"encryptMe", ciphertext}})
205+
assert.Nil(mt, err, "InsertOne error: %v", err)
206+
207+
session, err := client.StartSession()
208+
assert.Nil(mt, err, "StartSession error: %v", err)
209+
sessionCtx := mongo.NewSessionContext(mtest.Background, session)
210+
211+
capturedEvents = make([]event.CommandStartedEvent, 0)
212+
res := coll.FindOne(sessionCtx, bson.D{{}})
213+
assert.Nil(mt, res.Err(), "FindOne error: %v", res.Err())
214+
215+
assert.Equal(mt, len(capturedEvents), 2, "expected 2 events, got %v", len(capturedEvents))
216+
217+
// Assert the first event is the original find.
218+
event := capturedEvents[0]
219+
assert.Equal(mt, event.CommandName, "find", "expected command find, got %q", event.CommandName)
220+
assert.Equal(mt, event.DatabaseName, "db", "expected find on db, got %q", event.DatabaseName)
221+
222+
// Assert the find used the explicit session
223+
lsid := getLsid(mt, event.Command)
224+
assert.Nil(mt, err, "lsid not found on %v", event.Command)
225+
assert.Equal(mt, lsid, session.ID(), "expected lsid %v, but got %v", session.ID(), lsid)
226+
227+
// Assert the second event is the find on the keyvault.datakeys collection.
228+
event = capturedEvents[1]
229+
assert.Equal(mt, event.CommandName, "find", "expected command find, got %q", event.CommandName)
230+
assert.Equal(mt, event.DatabaseName, "keyvault", "expected find on keyvault, got %q", event.DatabaseName)
231+
232+
// Assert the find used an implicit session with an lsid != session.ID()
233+
lsid = getLsid(mt, event.Command)
234+
assert.Nil(mt, err, "lsid not found on %v", event.Command)
235+
assert.NotEqual(mt, lsid, session.ID(), "expected different lsid, but got %v", lsid)
236+
})
237+
}

mongo/mongocryptd.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ func newMcryptClient(opts *options.AutoEncryptionOptions) (*mcryptClient, error)
7777

7878
// markCommand executes the given command on mongocryptd.
7979
func (mc *mcryptClient) markCommand(ctx context.Context, dbName string, cmd bsoncore.Document) (bsoncore.Document, error) {
80+
// Remove the explicit session from the context if one is set.
81+
// The explicit session will be from a different client.
82+
// If an explicit session is set, it is applied after automatic encryption.
83+
ctx = NewSessionContext(ctx, nil)
8084
db := mc.client.Database(dbName, databaseOpts)
8185

8286
res, err := db.RunCommand(ctx, cmd).DecodeBytes()

0 commit comments

Comments
 (0)