Skip to content

Commit 32066ed

Browse files
author
Divjot Arora
authored
GODRIVER-1548 Add gridfs.File type (#353)
1 parent a2fd877 commit 32066ed

File tree

3 files changed

+148
-23
lines changed

3 files changed

+148
-23
lines changed

mongo/gridfs/bucket.go

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@ import (
1010
"bytes"
1111
"context"
1212
"errors"
13+
"fmt"
1314
"io"
1415
"time"
1516

1617
"go.mongodb.org/mongo-driver/bson"
17-
"go.mongodb.org/mongo-driver/bson/bsontype"
1818
"go.mongodb.org/mongo-driver/bson/primitive"
1919
"go.mongodb.org/mongo-driver/mongo"
2020
"go.mongodb.org/mongo-driver/mongo/options"
@@ -362,32 +362,23 @@ func (b *Bucket) openDownloadStream(filter interface{}, opts ...*options.FindOpt
362362
return nil, err
363363
}
364364

365-
fileLenElem, err := cursor.Current.LookupErr("length")
366-
if err != nil {
367-
return nil, err
368-
}
369-
fileIDElem, err := cursor.Current.LookupErr("_id")
370-
if err != nil {
371-
return nil, err
372-
}
373-
374-
var fileLen int64
375-
switch fileLenElem.Type {
376-
case bsontype.Int32:
377-
fileLen = int64(fileLenElem.Int32())
378-
default:
379-
fileLen = fileLenElem.Int64()
365+
// Unmarshal the files collection document into a File instance, which can be passed to newDownloadStream. The
366+
// File type implements bson.Unmarshaler and decodes through an internal struct that carries the BSON tags, so
367+
// File.ID is populated from the document's "_id" field without exposing BSON tags on the exported File type.
368+
var foundFile File
369+
if err = cursor.Decode(&foundFile); err != nil {
370+
return nil, fmt.Errorf("error decoding files collection document: %v", err)
380371
}
381372

382-
if fileLen == 0 {
383-
return newDownloadStream(nil, b.chunkSize, 0), nil
373+
if foundFile.Length == 0 {
374+
return newDownloadStream(nil, b.chunkSize, &foundFile), nil
384375
}
385376

386-
chunksCursor, err := b.findChunks(ctx, fileIDElem)
377+
chunksCursor, err := b.findChunks(ctx, foundFile.ID)
387378
if err != nil {
388379
return nil, err
389380
}
390-
return newDownloadStream(chunksCursor, b.chunkSize, int64(fileLen)), nil
381+
return newDownloadStream(chunksCursor, b.chunkSize, &foundFile), nil
391382
}
392383

393384
func deadlineContext(deadline time.Time) (context.Context, context.CancelFunc) {

mongo/gridfs/download_stream.go

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"math"
1414
"time"
1515

16+
"go.mongodb.org/mongo-driver/bson"
1617
"go.mongodb.org/mongo-driver/mongo"
1718
)
1819

@@ -37,18 +38,77 @@ type DownloadStream struct {
3738
expectedChunk int32 // index of next expected chunk
3839
readDeadline time.Time
3940
fileLen int64
41+
42+
// The pointer returned by GetFile. This should not be used in the actual DownloadStream code outside of the
43+
// newDownloadStream constructor because the values can be mutated by the user after calling GetFile. Instead,
44+
// any values needed in the code should be stored separately and copied over in the constructor.
45+
file *File
46+
}
47+
48+
// File represents a file stored in GridFS. This type can be used to access file information when downloading using the
49+
// DownloadStream.GetFile method.
50+
type File struct {
51+
// ID is the file's ID. This will match the file ID specified when uploading the file. If an upload helper that
52+
// does not require a file ID was used, this field will be a primitive.ObjectID.
53+
ID interface{}
54+
55+
// Length is the length of this file in bytes.
56+
Length int64
57+
58+
// ChunkSize is the maximum number of bytes for each chunk in this file.
59+
ChunkSize int32
60+
61+
// UploadDate is the time this file was added to GridFS in UTC.
62+
UploadDate time.Time
63+
64+
// Name is the name of this file.
65+
Name string
66+
67+
// Metadata is additional data that was specified when creating this file. This field can be unmarshalled into a
68+
// custom type using the bson.Unmarshal family of functions.
69+
Metadata bson.Raw
70+
}
71+
72+
var _ bson.Unmarshaler = (*File)(nil)
73+
74+
// unmarshalFile is an intermediate type used to unmarshal documents from the files collection; its fields are then
75+
// copied into a File by File.UnmarshalBSON. This type exists to avoid adding BSON struct tags to the exported File type.
76+
type unmarshalFile struct {
77+
ID interface{} `bson:"_id"`
78+
Length int64 `bson:"length"`
79+
ChunkSize int32 `bson:"chunkSize"`
80+
UploadDate time.Time `bson:"uploadDate"`
81+
Name string `bson:"filename"`
82+
Metadata bson.Raw `bson:"metadata"`
4083
}
4184

42-
func newDownloadStream(cursor *mongo.Cursor, chunkSize int32, fileLen int64) *DownloadStream {
43-
numChunks := int32(math.Ceil(float64(fileLen) / float64(chunkSize)))
85+
// UnmarshalBSON implements the bson.Unmarshaler interface.
86+
func (f *File) UnmarshalBSON(data []byte) error {
87+
var temp unmarshalFile
88+
if err := bson.Unmarshal(data, &temp); err != nil {
89+
return err
90+
}
91+
92+
f.ID = temp.ID
93+
f.Length = temp.Length
94+
f.ChunkSize = temp.ChunkSize
95+
f.UploadDate = temp.UploadDate
96+
f.Name = temp.Name
97+
f.Metadata = temp.Metadata
98+
return nil
99+
}
100+
101+
func newDownloadStream(cursor *mongo.Cursor, chunkSize int32, file *File) *DownloadStream {
102+
numChunks := int32(math.Ceil(float64(file.Length) / float64(chunkSize)))
44103

45104
return &DownloadStream{
46105
numChunks: numChunks,
47106
chunkSize: chunkSize,
48107
cursor: cursor,
49108
buffer: make([]byte, chunkSize),
50109
done: cursor == nil,
51-
fileLen: fileLen,
110+
fileLen: file.Length,
111+
file: file,
52112
}
53113
}
54114

@@ -161,6 +221,11 @@ func (ds *DownloadStream) Skip(skip int64) (int64, error) {
161221
return skip, nil
162222
}
163223

224+
// GetFile returns a File object representing the file being downloaded.
225+
func (ds *DownloadStream) GetFile() *File {
226+
return ds.file
227+
}
228+
164229
func (ds *DownloadStream) fillBuffer(ctx context.Context) error {
165230
if !ds.cursor.Next(ctx) {
166231
ds.done = true

mongo/integration/gridfs_test.go

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,75 @@ func TestGridFS(x *testing.T) {
181181
}
182182
})
183183

184+
mt.RunOpts("download", noClientOpts, func(mt *mtest.T) {
185+
mt.RunOpts("get file data", noClientOpts, func(mt *mtest.T) {
186+
// Tests for the DownloadStream.GetFile method.
187+
188+
fileName := "get-file-data-test"
189+
fileData := []byte{1, 2, 3, 4}
190+
fileMetadata := bson.D{{"k1", "v1"}, {"k2", "v2"}}
191+
rawMetadata, err := bson.Marshal(fileMetadata)
192+
assert.Nil(mt, err, "Marshal error: %v", err)
193+
uploadOpts := options.GridFSUpload().SetMetadata(fileMetadata)
194+
195+
testCases := []struct {
196+
name string
197+
fileID interface{}
198+
}{
199+
{"default ID", nil},
200+
{"custom ID type", "customID"},
201+
}
202+
for _, tc := range testCases {
203+
mt.Run(tc.name, func(mt *mtest.T) {
204+
// Create a new GridFS bucket.
205+
bucket, err := gridfs.NewBucket(mt.DB)
206+
assert.Nil(mt, err, "NewBucket error: %v", err)
207+
defer func() { _ = bucket.Drop() }()
208+
209+
// Upload the file and store the uploaded file ID.
210+
uploadedFileID := tc.fileID
211+
dataReader := bytes.NewReader(fileData)
212+
if uploadedFileID == nil {
213+
uploadedFileID, err = bucket.UploadFromStream(fileName, dataReader, uploadOpts)
214+
} else {
215+
err = bucket.UploadFromStreamWithID(tc.fileID, fileName, dataReader, uploadOpts)
216+
}
217+
assert.Nil(mt, err, "error uploading file: %v", err)
218+
219+
// The uploadDate field is calculated when the upload is complete. Manually fetch it from the
220+
// fs.files collection to use in assertions.
221+
filesColl := mt.DB.Collection("fs.files")
222+
uploadedFileDoc, err := filesColl.FindOne(mtest.Background, bson.D{}).DecodeBytes()
223+
assert.Nil(mt, err, "FindOne error: %v", err)
224+
uploadTime := uploadedFileDoc.Lookup("uploadDate").Time().UTC()
225+
226+
expectedFile := &gridfs.File{
227+
ID: uploadedFileID,
228+
Length: int64(len(fileData)),
229+
ChunkSize: gridfs.DefaultChunkSize,
230+
UploadDate: uploadTime,
231+
Name: fileName,
232+
Metadata: rawMetadata,
233+
}
234+
// For both methods that create a DownloadStream, open a stream and compare the file given by the
235+
// stream to the expected File object.
236+
mt.RunOpts("OpenDownloadStream", noClientOpts, func(mt *mtest.T) {
237+
downloadStream, err := bucket.OpenDownloadStream(uploadedFileID)
238+
assert.Nil(mt, err, "OpenDownloadStream error: %v", err)
239+
actualFile := downloadStream.GetFile()
240+
assert.Equal(mt, expectedFile, actualFile, "expected file %v, got %v", expectedFile, actualFile)
241+
})
242+
mt.RunOpts("OpenDownloadStreamByName", noClientOpts, func(mt *mtest.T) {
243+
downloadStream, err := bucket.OpenDownloadStreamByName(fileName)
244+
assert.Nil(mt, err, "OpenDownloadStream error: %v", err)
245+
actualFile := downloadStream.GetFile()
246+
assert.Equal(mt, expectedFile, actualFile, "expected file %v, got %v", expectedFile, actualFile)
247+
})
248+
})
249+
}
250+
})
251+
})
252+
184253
mt.RunOpts("round trip", mtest.NewOptions().MaxServerVersion("3.6"), func(mt *mtest.T) {
185254
skipRoundTripTest(mt)
186255
oneK := 1024

0 commit comments

Comments
 (0)