Skip to content

GODRIVER-1548 Add gridfs.File type #353

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 11 additions & 20 deletions mongo/gridfs/bucket.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ import (
"bytes"
"context"
"errors"
"fmt"
"io"
"time"

"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
Expand Down Expand Up @@ -362,32 +362,23 @@ func (b *Bucket) openDownloadStream(filter interface{}, opts ...*options.FindOpt
return nil, err
}

fileLenElem, err := cursor.Current.LookupErr("length")
if err != nil {
return nil, err
}
fileIDElem, err := cursor.Current.LookupErr("_id")
if err != nil {
return nil, err
}

var fileLen int64
switch fileLenElem.Type {
case bsontype.Int32:
fileLen = int64(fileLenElem.Int32())
default:
fileLen = fileLenElem.Int64()
// Unmarshal the data into a File instance, which can be passed to newDownloadStream. File implements
// bson.Unmarshaler, so the "_id" key and the other files collection keys are mapped onto the File fields by
// File.UnmarshalBSON without exposing BSON struct tags on the exported File type.
var foundFile File
if err = cursor.Decode(&foundFile); err != nil {
return nil, fmt.Errorf("error decoding files collection document: %v", err)
}

if fileLen == 0 {
return newDownloadStream(nil, b.chunkSize, 0), nil
if foundFile.Length == 0 {
return newDownloadStream(nil, b.chunkSize, &foundFile), nil
}

chunksCursor, err := b.findChunks(ctx, fileIDElem)
chunksCursor, err := b.findChunks(ctx, foundFile.ID)
if err != nil {
return nil, err
}
return newDownloadStream(chunksCursor, b.chunkSize, int64(fileLen)), nil
return newDownloadStream(chunksCursor, b.chunkSize, &foundFile), nil
}

func deadlineContext(deadline time.Time) (context.Context, context.CancelFunc) {
Expand Down
71 changes: 68 additions & 3 deletions mongo/gridfs/download_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"math"
"time"

"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
)

Expand All @@ -37,18 +38,77 @@ type DownloadStream struct {
expectedChunk int32 // index of next expected chunk
readDeadline time.Time
fileLen int64

// The pointer returned by GetFile. This should not be used in the actual DownloadStream code outside of the
// newDownloadStream constructor because the values can be mutated by the user after calling GetFile. Instead,
// any values needed in the code should be stored separately and copied over in the constructor.
file *File
}

// File represents a file stored in GridFS. This type can be used to access file information when downloading using the
// DownloadStream.GetFile method.
//
// The struct deliberately carries no BSON struct tags. Decoding a files collection document into a File goes through
// File.UnmarshalBSON, which maps the document keys onto these fields via the unexported unmarshalFile type.
type File struct {
// ID is the file's ID. This will match the file ID specified when uploading the file. If an upload helper that
// does not require a file ID was used, this field will be a primitive.ObjectID. It is populated from the "_id"
// key of the files collection document.
ID interface{}

// Length is the length of this file in bytes. It is populated from the "length" key.
Length int64

// ChunkSize is the maximum number of bytes for each chunk in this file. It is populated from the "chunkSize" key.
ChunkSize int32

// UploadDate is the time this file was added to GridFS in UTC. It is populated from the "uploadDate" key.
UploadDate time.Time

// Name is the name of this file. It is populated from the "filename" key.
Name string

// Metadata is additional data that was specified when creating this file. This field can be unmarshalled into a
// custom type using the bson.Unmarshal family of functions. It is populated from the "metadata" key.
Metadata bson.Raw
}

// Compile-time assertion that *File satisfies the bson.Unmarshaler interface.
var _ bson.Unmarshaler = (*File)(nil)

// unmarshalFile is a temporary type used to unmarshal documents from the files collection and can be transformed into
// a File instance. This type exists to avoid adding BSON struct tags to the exported File type.
//
// Its fields mirror the fields of File one-to-one; File.UnmarshalBSON decodes into this type and then copies every
// field across. The struct tags below name the files collection document keys each field is read from.
type unmarshalFile struct {
ID interface{} `bson:"_id"`
Length int64 `bson:"length"`
ChunkSize int32 `bson:"chunkSize"`
UploadDate time.Time `bson:"uploadDate"`
Name string `bson:"filename"`
Metadata bson.Raw `bson:"metadata"`
}

func newDownloadStream(cursor *mongo.Cursor, chunkSize int32, fileLen int64) *DownloadStream {
numChunks := int32(math.Ceil(float64(fileLen) / float64(chunkSize)))
// UnmarshalBSON implements the bson.Unmarshaler interface.
//
// The raw document is first decoded into the tagged unmarshalFile helper and the result is then copied into f. If
// decoding fails, the error from bson.Unmarshal is returned as-is and f is left unmodified.
func (f *File) UnmarshalBSON(data []byte) error {
	var decoded unmarshalFile
	err := bson.Unmarshal(data, &decoded)
	if err != nil {
		return err
	}

	// File has exactly these six fields, so a single keyed literal overwrites the same state as assigning each
	// field individually.
	*f = File{
		ID:         decoded.ID,
		Length:     decoded.Length,
		ChunkSize:  decoded.ChunkSize,
		UploadDate: decoded.UploadDate,
		Name:       decoded.Name,
		Metadata:   decoded.Metadata,
	}
	return nil
}

// newDownloadStream creates a DownloadStream for file whose chunk documents are read from cursor.
//
// chunkSize is used both to size the internal read buffer and to compute the number of expected chunks as
// ceil(file.Length / chunkSize). A nil cursor marks the stream as done immediately. file.Length is copied into the
// stream's own fileLen field so that the stream does not depend on the File pointer, which is handed out to callers
// by GetFile and may be mutated by them.
func newDownloadStream(cursor *mongo.Cursor, chunkSize int32, file *File) *DownloadStream {
	numChunks := int32(math.Ceil(float64(file.Length) / float64(chunkSize)))

	return &DownloadStream{
		numChunks: numChunks,
		chunkSize: chunkSize,
		cursor:    cursor,
		buffer:    make([]byte, chunkSize),
		done:      cursor == nil,
		fileLen:   file.Length,
		file:      file,
	}
}

Expand Down Expand Up @@ -161,6 +221,11 @@ func (ds *DownloadStream) Skip(skip int64) (int64, error) {
return skip, nil
}

// GetFile returns a File object representing the file being downloaded. The returned pointer is the one stored on the
// stream, so the same File instance is returned on every call.
func (ds *DownloadStream) GetFile() *File {
return ds.file
}

func (ds *DownloadStream) fillBuffer(ctx context.Context) error {
if !ds.cursor.Next(ctx) {
ds.done = true
Expand Down
69 changes: 69 additions & 0 deletions mongo/integration/gridfs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,75 @@ func TestGridFS(x *testing.T) {
}
})

mt.RunOpts("download", noClientOpts, func(mt *mtest.T) {
mt.RunOpts("get file data", noClientOpts, func(mt *mtest.T) {
// Tests for the DownloadStream.GetFile method.

fileName := "get-file-data-test"
fileData := []byte{1, 2, 3, 4}
fileMetadata := bson.D{{"k1", "v1"}, {"k2", "v2"}}
rawMetadata, err := bson.Marshal(fileMetadata)
assert.Nil(mt, err, "Marshal error: %v", err)
uploadOpts := options.GridFSUpload().SetMetadata(fileMetadata)

testCases := []struct {
name string
fileID interface{}
}{
{"default ID", nil},
{"custom ID type", "customID"},
}
for _, tc := range testCases {
mt.Run(tc.name, func(mt *mtest.T) {
// Create a new GridFS bucket.
bucket, err := gridfs.NewBucket(mt.DB)
assert.Nil(mt, err, "NewBucket error: %v", err)
defer func() { _ = bucket.Drop() }()

// Upload the file and store the uploaded file ID.
uploadedFileID := tc.fileID
dataReader := bytes.NewReader(fileData)
if uploadedFileID == nil {
uploadedFileID, err = bucket.UploadFromStream(fileName, dataReader, uploadOpts)
} else {
err = bucket.UploadFromStreamWithID(tc.fileID, fileName, dataReader, uploadOpts)
}
assert.Nil(mt, err, "error uploading file: %v", err)

// The uploadDate field is calculated when the upload is complete. Manually fetch it from the
// fs.files collection to use in assertions.
filesColl := mt.DB.Collection("fs.files")
uploadedFileDoc, err := filesColl.FindOne(mtest.Background, bson.D{}).DecodeBytes()
assert.Nil(mt, err, "FindOne error: %v", err)
uploadTime := uploadedFileDoc.Lookup("uploadDate").Time().UTC()

expectedFile := &gridfs.File{
ID: uploadedFileID,
Length: int64(len(fileData)),
ChunkSize: gridfs.DefaultChunkSize,
UploadDate: uploadTime,
Name: fileName,
Metadata: rawMetadata,
}
// For both methods that create a DownloadStream, open a stream and compare the file given by the
// stream to the expected File object.
mt.RunOpts("OpenDownloadStream", noClientOpts, func(mt *mtest.T) {
downloadStream, err := bucket.OpenDownloadStream(uploadedFileID)
assert.Nil(mt, err, "OpenDownloadStream error: %v", err)
actualFile := downloadStream.GetFile()
assert.Equal(mt, expectedFile, actualFile, "expected file %v, got %v", expectedFile, actualFile)
})
// Open the stream by file name and compare the File it reports to the expected File object.
mt.RunOpts("OpenDownloadStreamByName", noClientOpts, func(mt *mtest.T) {
	downloadStream, err := bucket.OpenDownloadStreamByName(fileName)
	// Name the method that was actually called so a failure is attributed to the right code path.
	assert.Nil(mt, err, "OpenDownloadStreamByName error: %v", err)
	actualFile := downloadStream.GetFile()
	assert.Equal(mt, expectedFile, actualFile, "expected file %v, got %v", expectedFile, actualFile)
})
})
}
})
})

mt.RunOpts("round trip", mtest.NewOptions().MaxServerVersion("3.6"), func(mt *mtest.T) {
skipRoundTripTest(mt)
oneK := 1024
Expand Down