Skip to content

Commit 48ca246

Browse files
authored
GODRIVER-2232 Allow bsoncore.Array as type for aggregation pipelines (#807)
1 parent d6d6625 commit 48ca246

File tree

2 files changed

+92
-2
lines changed

2 files changed

+92
-2
lines changed

mongo/mongo.go

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,20 +210,43 @@ func transformAggregatePipeline(registry *bsoncodec.Registry, pipeline interface
210210
return nil, false, fmt.Errorf("can only transform slices and arrays into aggregation pipelines, but got %v", val.Kind())
211211
}
212212

213-
aidx, arr := bsoncore.AppendArrayStart(nil)
214213
var hasOutputStage bool
215214
valLen := val.Len()
216215

216+
switch t := pipeline.(type) {
217217
// Explicitly forbid non-empty pipelines that are semantically single documents
218218
// and are implemented as slices.
219-
switch t := pipeline.(type) {
220219
case bson.D, bson.Raw, bsoncore.Document:
221220
if valLen > 0 {
222221
return nil, false,
223222
fmt.Errorf("%T is not an allowed pipeline type as it represents a single document. Use bson.A or mongo.Pipeline instead", t)
224223
}
224+
// bsoncore.Arrays do not need to be transformed. Only check validity and presence of output stage.
225+
case bsoncore.Array:
226+
if err := t.Validate(); err != nil {
227+
return nil, false, err
228+
}
229+
230+
values, err := t.Values()
231+
if err != nil {
232+
return nil, false, err
233+
}
234+
235+
numVals := len(values)
236+
if numVals == 0 {
237+
return bsoncore.Document(t), false, nil
238+
}
239+
240+
// If not empty, check if first value of the last stage is $out or $merge.
241+
if lastStage, ok := values[numVals-1].DocumentOK(); ok {
242+
if elem, err := lastStage.IndexErr(0); err == nil && (elem.Key() == "$out" || elem.Key() == "$merge") {
243+
hasOutputStage = true
244+
}
245+
}
246+
return bsoncore.Document(t), hasOutputStage, nil
225247
}
226248

249+
aidx, arr := bsoncore.AppendArrayStart(nil)
227250
for idx := 0; idx < valLen; idx++ {
228251
doc, err := transformBsoncoreDocument(registry, val.Index(idx).Interface(), true, fmt.Sprintf("pipeline stage :%v", idx))
229252
if err != nil {

mongo/mongo_test.go

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,16 +120,43 @@ func TestMongoHelpers(t *testing.T) {
120120
})
121121
})
122122
t.Run("transform aggregate pipeline", func(t *testing.T) {
123+
// []byte of [{{"$limit", 12345}}]
123124
index, arr := bsoncore.AppendArrayStart(nil)
124125
dindex, arr := bsoncore.AppendDocumentElementStart(arr, "0")
125126
arr = bsoncore.AppendInt32Element(arr, "$limit", 12345)
126127
arr, _ = bsoncore.AppendDocumentEnd(arr, dindex)
127128
arr, _ = bsoncore.AppendArrayEnd(arr, index)
128129

130+
// []byte of {{"x", 1}}
129131
index, doc := bsoncore.AppendDocumentStart(nil)
130132
doc = bsoncore.AppendInt32Element(doc, "x", 1)
131133
doc, _ = bsoncore.AppendDocumentEnd(doc, index)
132134

135+
// bsoncore.Array of [{{"$merge", {}}}]
136+
mergeStage := bsoncore.NewDocumentBuilder().
137+
StartDocument("$merge").
138+
FinishDocument().
139+
Build()
140+
arrMergeStage := bsoncore.NewArrayBuilder().AppendDocument(mergeStage).Build()
141+
142+
fooStage := bsoncore.NewDocumentBuilder().AppendString("foo", "bar").Build()
143+
bazStage := bsoncore.NewDocumentBuilder().AppendString("baz", "qux").Build()
144+
outStage := bsoncore.NewDocumentBuilder().AppendString("$out", "myColl").Build()
145+
146+
// bsoncore.Array of [{{"foo", "bar"}}, {{"baz", "qux"}}, {{"$out", "myColl"}}]
147+
arrOutStage := bsoncore.NewArrayBuilder().
148+
AppendDocument(fooStage).
149+
AppendDocument(bazStage).
150+
AppendDocument(outStage).
151+
Build()
152+
153+
// bsoncore.Array of [{{"foo", "bar"}}, {{"$out", "myColl"}}, {{"baz", "qux"}}]
154+
arrMiddleOutStage := bsoncore.NewArrayBuilder().
155+
AppendDocument(fooStage).
156+
AppendDocument(outStage).
157+
AppendDocument(bazStage).
158+
Build()
159+
133160
testCases := []struct {
134161
name string
135162
pipeline interface{}
@@ -388,6 +415,46 @@ func TestMongoHelpers(t *testing.T) {
388415
false,
389416
nil,
390417
},
418+
{
419+
"bsoncore.Array/success",
420+
bsoncore.Array(arr),
421+
bson.A{
422+
bson.D{{"$limit", int32(12345)}},
423+
},
424+
false,
425+
nil,
426+
},
427+
{
428+
"bsoncore.Array/mergeStage",
429+
arrMergeStage,
430+
bson.A{
431+
bson.D{{"$merge", bson.D{}}},
432+
},
433+
true,
434+
nil,
435+
},
436+
{
437+
"bsoncore.Array/outStage",
438+
arrOutStage,
439+
bson.A{
440+
bson.D{{"foo", "bar"}},
441+
bson.D{{"baz", "qux"}},
442+
bson.D{{"$out", "myColl"}},
443+
},
444+
true,
445+
nil,
446+
},
447+
{
448+
"bsoncore.Array/middleOutStage",
449+
arrMiddleOutStage,
450+
bson.A{
451+
bson.D{{"foo", "bar"}},
452+
bson.D{{"$out", "myColl"}},
453+
bson.D{{"baz", "qux"}},
454+
},
455+
false,
456+
nil,
457+
},
391458
}
392459

393460
for _, tc := range testCases {

0 commit comments

Comments
 (0)