Skip to content

Commit 1ce074c

Browse files
author
Divjot Arora
committed
GODRIVER-1696 CRUD update/replace bug fixes
1 parent 1962fcb commit 1ce074c

File tree

4 files changed

+50
-32
lines changed

4 files changed

+50
-32
lines changed

mongo/bulk_write.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,7 @@ func createDeleteDoc(filter interface{}, collation *options.Collation, hint inte
281281
func (bw *bulkWrite) runUpdate(ctx context.Context, batch bulkWriteBatch) (operation.UpdateResult, error) {
282282
docs := make([]bsoncore.Document, len(batch.models))
283283
var hasHint bool
284+
var hasArrayFilters bool
284285
for i, model := range batch.models {
285286
var doc bsoncore.Document
286287
var err error
@@ -292,12 +293,14 @@ func (bw *bulkWrite) runUpdate(ctx context.Context, batch bulkWriteBatch) (opera
292293
hasHint = hasHint || (converted.Hint != nil)
293294
case *UpdateOneModel:
294295
doc, err = createUpdateDoc(converted.Filter, converted.Update, converted.Hint, converted.ArrayFilters, converted.Collation, converted.Upsert, false,
295-
false, bw.collection.registry)
296+
true, bw.collection.registry)
296297
hasHint = hasHint || (converted.Hint != nil)
298+
hasArrayFilters = hasArrayFilters || (converted.ArrayFilters != nil)
297299
case *UpdateManyModel:
298300
doc, err = createUpdateDoc(converted.Filter, converted.Update, converted.Hint, converted.ArrayFilters, converted.Collation, converted.Upsert, true,
299-
false, bw.collection.registry)
301+
true, bw.collection.registry)
300302
hasHint = hasHint || (converted.Hint != nil)
303+
hasArrayFilters = hasArrayFilters || (converted.ArrayFilters != nil)
301304
}
302305
if err != nil {
303306
return operation.UpdateResult{}, err
@@ -310,7 +313,8 @@ func (bw *bulkWrite) runUpdate(ctx context.Context, batch bulkWriteBatch) (opera
310313
Session(bw.session).WriteConcern(bw.writeConcern).CommandMonitor(bw.collection.client.monitor).
311314
ServerSelector(bw.selector).ClusterClock(bw.collection.client.clock).
312315
Database(bw.collection.db.name).Collection(bw.collection.name).
313-
Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.crypt).Hint(hasHint)
316+
Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.crypt).Hint(hasHint).
317+
ArrayFilters(hasArrayFilters)
314318
if bw.ordered != nil {
315319
op = op.Ordered(*bw.ordered)
316320
}
@@ -350,6 +354,7 @@ func createUpdateDoc(
350354
if err != nil {
351355
return nil, err
352356
}
357+
353358
updateDoc = bsoncore.AppendValueElement(updateDoc, "u", u)
354359

355360
if multi {

mongo/collection.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,8 @@ func (coll *Collection) updateOrReplace(ctx context.Context, filter bsoncore.Doc
550550
Session(sess).WriteConcern(wc).CommandMonitor(coll.client.monitor).
551551
ServerSelector(selector).ClusterClock(coll.client.clock).
552552
Database(coll.db.name).Collection(coll.name).
553-
Deployment(coll.client.deployment).Crypt(coll.client.crypt).Hint(uo.Hint != nil)
553+
Deployment(coll.client.deployment).Crypt(coll.client.crypt).Hint(uo.Hint != nil).
554+
ArrayFilters(uo.ArrayFilters != nil)
554555

555556
if uo.BypassDocumentValidation != nil && *uo.BypassDocumentValidation {
556557
op = op.BypassDocumentValidation(*uo.BypassDocumentValidation)
@@ -669,8 +670,8 @@ func (coll *Collection) ReplaceOne(ctx context.Context, filter interface{},
669670
return nil, err
670671
}
671672

672-
if elem, err := r.IndexErr(0); err == nil && strings.HasPrefix(elem.Key(), "$") {
673-
return nil, errors.New("replacement document cannot contains keys beginning with '$")
673+
if err := ensureNoDollarKey(r); err != nil {
674+
return nil, err
674675
}
675676

676677
updateOptions := make([]*options.UpdateOptions, 0, len(opts))

mongo/mongo.go

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,14 @@ func ensureDollarKeyv2(doc bsoncore.Document) error {
253253
return nil
254254
}
255255

256+
func ensureNoDollarKey(doc bsoncore.Document) error {
257+
if elem, err := doc.IndexErr(0); err == nil && strings.HasPrefix(elem.Key(), "$") {
258+
return errors.New("replacement document cannot contains keys beginning with '$")
259+
}
260+
261+
return nil
262+
}
263+
256264
func transformAggregatePipeline(registry *bsoncodec.Registry, pipeline interface{}) (bsonx.Arr, error) {
257265
pipelineArr := bsonx.Arr{}
258266
switch t := pipeline.(type) {
@@ -333,7 +341,12 @@ func transformAggregatePipelinev2(registry *bsoncodec.Registry, pipeline interfa
333341
}
334342
}
335343

336-
func transformUpdateValue(registry *bsoncodec.Registry, update interface{}, checkDocDollarKey bool) (bsoncore.Value, error) {
344+
func transformUpdateValue(registry *bsoncodec.Registry, update interface{}, dollarKeysAllowed bool) (bsoncore.Value, error) {
345+
documentCheckerFunc := ensureDollarKeyv2
346+
if !dollarKeysAllowed {
347+
documentCheckerFunc = ensureNoDollarKey
348+
}
349+
337350
var u bsoncore.Value
338351
var err error
339352
switch t := update.(type) {
@@ -346,42 +359,27 @@ func transformUpdateValue(registry *bsoncodec.Registry, update interface{}, chec
346359
return u, err
347360
}
348361

349-
if checkDocDollarKey {
350-
err = ensureDollarKeyv2(u.Data)
351-
}
352-
return u, err
362+
return u, documentCheckerFunc(u.Data)
353363
case bson.Raw:
354364
u.Type = bsontype.EmbeddedDocument
355365
u.Data = t
356-
if checkDocDollarKey {
357-
err = ensureDollarKeyv2(u.Data)
358-
}
359-
return u, err
366+
return u, documentCheckerFunc(u.Data)
360367
case bsoncore.Document:
361368
u.Type = bsontype.EmbeddedDocument
362369
u.Data = t
363-
if checkDocDollarKey {
364-
err = ensureDollarKeyv2(u.Data)
365-
}
366-
return u, err
370+
return u, documentCheckerFunc(u.Data)
367371
case []byte:
368372
u.Type = bsontype.EmbeddedDocument
369373
u.Data = t
370-
if checkDocDollarKey {
371-
err = ensureDollarKeyv2(u.Data)
372-
}
373-
return u, err
374+
return u, documentCheckerFunc(u.Data)
374375
case bsoncodec.Marshaler:
375376
u.Type = bsontype.EmbeddedDocument
376377
u.Data, err = t.MarshalBSON()
377378
if err != nil {
378379
return u, err
379380
}
380381

381-
if checkDocDollarKey {
382-
err = ensureDollarKeyv2(u.Data)
383-
}
384-
return u, err
382+
return u, documentCheckerFunc(u.Data)
385383
case bsoncodec.ValueMarshaler:
386384
u.Type, u.Data, err = t.MarshalBSONValue()
387385
if err != nil {
@@ -403,10 +401,7 @@ func transformUpdateValue(registry *bsoncodec.Registry, update interface{}, chec
403401
return u, err
404402
}
405403

406-
if checkDocDollarKey {
407-
err = ensureDollarKeyv2(u.Data)
408-
}
409-
return u, err
404+
return u, documentCheckerFunc(u.Data)
410405
}
411406

412407
u.Type = bsontype.Array
@@ -418,7 +413,7 @@ func transformUpdateValue(registry *bsoncodec.Registry, update interface{}, chec
418413
return u, err
419414
}
420415

421-
if err := ensureDollarKeyv2(doc); err != nil {
416+
if err := documentCheckerFunc(doc); err != nil {
422417
return u, err
423418
}
424419

x/mongo/driver/operation/update.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ type Update struct {
3434
database string
3535
deployment driver.Deployment
3636
hint *bool
37+
arrayFilters *bool
3738
selector description.ServerSelector
3839
writeConcern *writeconcern.WriteConcern
3940
retry *driver.RetryMode
@@ -177,6 +178,11 @@ func (u *Update) command(dst []byte, desc description.SelectedServer) ([]byte, e
177178
return nil, errUnacknowledgedHint
178179
}
179180
}
181+
if u.arrayFilters != nil && *u.arrayFilters {
182+
if desc.WireVersion == nil || !desc.WireVersion.Includes(6) {
183+
return nil, errors.New("the 'arrayFilters' command parameter requires a minimum server wire version of 6")
184+
}
185+
}
180186

181187
return dst, nil
182188
}
@@ -204,6 +210,17 @@ func (u *Update) Hint(hint bool) *Update {
204210
return u
205211
}
206212

213+
// ArrayFilters is a flag to indicate that the update document contains an arrayFilters field. This option is only
214+
// supported on server versions 3.6 and higher. For servers < 3.6, the driver will return an error.
215+
func (u *Update) ArrayFilters(arrayFilters bool) *Update {
216+
if u == nil {
217+
u = new(Update)
218+
}
219+
220+
u.arrayFilters = &arrayFilters
221+
return u
222+
}
223+
207224
// Ordered sets ordered. If true, when a write fails, the operation will return the error, when
208225
// false write failures do not stop execution of the operation.
209226
func (u *Update) Ordered(ordered bool) *Update {

0 commit comments

Comments
 (0)