Skip to content

Commit 3fd7112

Browse files
iwysiutsedgwick
authored andcommitted
GODRIVER-1962 unwrap errors in WithTransaction (mongodb#634)
1 parent d508609 commit 3fd7112

File tree

3 files changed

+46
-7
lines changed

3 files changed

+46
-7
lines changed

mongo/errors.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,16 +127,21 @@ func unwrap(err error) error {
127127
return u.Unwrap()
128128
}
129129

130-
// IsNetworkError returns true if err is a network error
131-
func IsNetworkError(err error) bool {
130+
// errorHasLabel returns true if err contains the specified label
131+
func errorHasLabel(err error, label string) bool {
132132
for ; err != nil; err = unwrap(err) {
133133
if e, ok := err.(ServerError); ok {
134-
return e.HasErrorLabel("NetworkError")
134+
return e.HasErrorLabel(label)
135135
}
136136
}
137137
return false
138138
}
139139

140+
// IsNetworkError returns true if err is a network error
141+
func IsNetworkError(err error) bool {
142+
return errorHasLabel(err, "NetworkError")
143+
}
144+
140145
// MongocryptError represents an libmongocrypt error during client-side encryption.
141146
type MongocryptError struct {
142147
Code int32

mongo/session.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,10 +193,8 @@ func (s *sessionImpl) WithTransaction(ctx context.Context, fn func(sessCtx Sessi
193193
default:
194194
}
195195

196-
if cerr, ok := err.(CommandError); ok {
197-
if cerr.HasErrorLabel(driver.TransientTransactionError) {
198-
continue
199-
}
196+
if errorHasLabel(err, driver.TransientTransactionError) {
197+
continue
200198
}
201199
return res, err
202200
}

mongo/with_transactions_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,18 @@ var (
3535
withTxnFailedEvents []*event.CommandFailedEvent
3636
)
3737

38+
type wrappedError struct {
39+
err error
40+
}
41+
42+
func (we wrappedError) Error() string {
43+
return we.err.Error()
44+
}
45+
46+
func (we wrappedError) Unwrap() error {
47+
return we.err
48+
}
49+
3850
func TestConvenientTransactions(t *testing.T) {
3951
client := setupConvenientTransactions(t)
4052
db := client.Database("TestConvenientTransactions")
@@ -381,6 +393,30 @@ func TestConvenientTransactions(t *testing.T) {
381393
// Assert that transaction is canceled within 500ms and not 2 seconds.
382394
assert.Soon(t, callback, 500*time.Millisecond)
383395
})
396+
t.Run("wrapped transient transaction error retried", func(t *testing.T) {
397+
sess, err := client.StartSession()
398+
assert.Nil(t, err, "StartSession error: %v", err)
399+
defer sess.EndSession(context.Background())
400+
401+
// returnError tracks whether or not the callback is being retried
402+
returnError := true
403+
res, err := sess.WithTransaction(context.Background(), func(sessCtx SessionContext) (interface{}, error) {
404+
if returnError {
405+
returnError = false
406+
return nil, wrappedError{
407+
CommandError{
408+
Name: "test Error",
409+
Labels: []string{driver.TransientTransactionError},
410+
},
411+
}
412+
}
413+
return false, nil
414+
})
415+
assert.Nil(t, err, "WithTransaction error: %v", err)
416+
resBool, ok := res.(bool)
417+
assert.True(t, ok, "expected result type %T, got %T", false, res)
418+
assert.False(t, resBool, "expected result false, got %v", resBool)
419+
})
384420
}
385421

386422
func setupConvenientTransactions(t *testing.T, extraClientOpts ...*options.ClientOptions) *Client {

0 commit comments

Comments
 (0)