Skip to content

refactor(NODE-3332): withTransaction uses async/await #4053

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Apr 15, 2024
Merged
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 93 additions & 108 deletions src/sessions.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
/* eslint-disable github/no-then */
import { promisify } from 'util';

import { Binary, type Document, Long, type Timestamp } from './bson';
import type { CommandOptions, Connection } from './cmap/connection';
Expand Down Expand Up @@ -40,7 +39,6 @@ import {
import {
ByteUtils,
calculateDurationInMs,
type Callback,
commandSupportsReadConcern,
isPromiseLike,
List,
Expand Down Expand Up @@ -425,14 +423,14 @@ export class ClientSession extends TypedEventEmitter<ClientSessionEvents> {
* Commits the currently active transaction in this session.
*/
async commitTransaction(): Promise<void> {
return await endTransactionAsync(this, 'commitTransaction');
return await endTransaction(this, 'commitTransaction');
}

/**
* Aborts the currently active transaction in this session.
*/
async abortTransaction(): Promise<void> {
return await endTransactionAsync(this, 'abortTransaction');
return await endTransaction(this, 'abortTransaction');
}

/**
Expand Down Expand Up @@ -555,33 +553,33 @@ function isMaxTimeMSExpiredError(err: MongoError) {
);
}

function attemptTransactionCommit<T>(
async function attemptTransactionCommit<T>(
session: ClientSession,
startTime: number,
fn: WithTransactionCallback<T>,
result: any,
options: TransactionOptions
): Promise<T> {
return session.commitTransaction().then(
() => result,
(err: MongoError) => {
if (
err instanceof MongoError &&
hasNotTimedOut(startTime, MAX_WITH_TRANSACTION_TIMEOUT) &&
!isMaxTimeMSExpiredError(err)
) {
if (err.hasErrorLabel(MongoErrorLabel.UnknownTransactionCommitResult)) {
return attemptTransactionCommit(session, startTime, fn, result, options);
}

if (err.hasErrorLabel(MongoErrorLabel.TransientTransactionError)) {
return attemptTransaction(session, startTime, fn, options);
}
try {
await session.commitTransaction();
return result;
} catch (err) {
if (
err instanceof MongoError &&
hasNotTimedOut(startTime, MAX_WITH_TRANSACTION_TIMEOUT) &&
!isMaxTimeMSExpiredError(err)
) {
if (err.hasErrorLabel(MongoErrorLabel.UnknownTransactionCommitResult)) {
return await attemptTransactionCommit(session, startTime, fn, result, options);
}

throw err;
if (err.hasErrorLabel(MongoErrorLabel.TransientTransactionError)) {
return await attemptTransaction(session, startTime, fn, options);
}
}
);

throw err;
}
}

const USER_EXPLICIT_TXN_END_STATES = new Set<TxnState>([
Expand All @@ -594,7 +592,7 @@ function userExplicitlyEndedTransaction(session: ClientSession) {
return USER_EXPLICIT_TXN_END_STATES.has(session.transaction.state);
}

function attemptTransaction<T>(
async function attemptTransaction<T>(
session: ClientSession,
startTime: number,
fn: WithTransactionCallback<T>,
Expand All @@ -610,65 +608,48 @@ function attemptTransaction<T>(
}

if (!isPromiseLike(promise)) {
session.abortTransaction().catch(() => null);
return Promise.reject(
new MongoInvalidArgumentError('Function provided to `withTransaction` must return a Promise')
await session.abortTransaction().catch(() => null);
throw new MongoInvalidArgumentError(
'Function provided to `withTransaction` must return a Promise'
);
}

return promise.then(
result => {
if (userExplicitlyEndedTransaction(session)) {
return result;
}

return attemptTransactionCommit(session, startTime, fn, result, options);
},
err => {
function maybeRetryOrThrow(err: MongoError): Promise<any> {
if (
err instanceof MongoError &&
err.hasErrorLabel(MongoErrorLabel.TransientTransactionError) &&
hasNotTimedOut(startTime, MAX_WITH_TRANSACTION_TIMEOUT)
) {
return attemptTransaction(session, startTime, fn, options);
}

if (isMaxTimeMSExpiredError(err)) {
err.addErrorLabel(MongoErrorLabel.UnknownTransactionCommitResult);
}

throw err;
}
try {
const result = await promise;
if (userExplicitlyEndedTransaction(session)) {
return result;
}
return await attemptTransactionCommit(session, startTime, fn, result, options);
} catch (err) {
if (session.inTransaction()) {
await session.abortTransaction();
}

if (session.inTransaction()) {
return session.abortTransaction().then(() => maybeRetryOrThrow(err));
}
if (
err instanceof MongoError &&
err.hasErrorLabel(MongoErrorLabel.TransientTransactionError) &&
hasNotTimedOut(startTime, MAX_WITH_TRANSACTION_TIMEOUT)
) {
return await attemptTransaction(session, startTime, fn, options);
}

return maybeRetryOrThrow(err);
if (isMaxTimeMSExpiredError(err)) {
err.addErrorLabel(MongoErrorLabel.UnknownTransactionCommitResult);
}
);
}

const endTransactionAsync = promisify(
endTransaction as (
session: ClientSession,
commandName: 'abortTransaction' | 'commitTransaction',
callback: (error: Error) => void
) => void
);
throw err;
}
}

function endTransaction(
async function endTransaction(
session: ClientSession,
commandName: 'abortTransaction' | 'commitTransaction',
callback: Callback<void>
) {
commandName: 'abortTransaction' | 'commitTransaction'
): Promise<void> {
// handle any initial problematic cases
const txnState = session.transaction.state;

if (txnState === TxnState.NO_TRANSACTION) {
callback(new MongoTransactionError('No transaction started'));
return;
throw new MongoTransactionError('No transaction started');
}

if (commandName === 'commitTransaction') {
Expand All @@ -678,37 +659,32 @@ function endTransaction(
) {
// the transaction was never started, we can safely exit here
session.transaction.transition(TxnState.TRANSACTION_COMMITTED_EMPTY);
callback();
return;
}

if (txnState === TxnState.TRANSACTION_ABORTED) {
callback(
new MongoTransactionError('Cannot call commitTransaction after calling abortTransaction')
throw new MongoTransactionError(
'Cannot call commitTransaction after calling abortTransaction'
);
return;
}
} else {
if (txnState === TxnState.STARTING_TRANSACTION) {
// the transaction was never started, we can safely exit here
session.transaction.transition(TxnState.TRANSACTION_ABORTED);
callback();
return;
}

if (txnState === TxnState.TRANSACTION_ABORTED) {
callback(new MongoTransactionError('Cannot call abortTransaction twice'));
return;
throw new MongoTransactionError('Cannot call abortTransaction twice');
}

if (
txnState === TxnState.TRANSACTION_COMMITTED ||
txnState === TxnState.TRANSACTION_COMMITTED_EMPTY
) {
callback(
new MongoTransactionError('Cannot call abortTransaction after calling commitTransaction')
throw new MongoTransactionError(
'Cannot call abortTransaction after calling commitTransaction'
);
return;
}
}

Expand Down Expand Up @@ -741,9 +717,8 @@ function endTransaction(
if (session.loadBalanced) {
maybeClearPinnedConnection(session, { force: false });
}

// The spec indicates that we should ignore all errors on `abortTransaction`
return callback();
// The spec indicates that if the operation times out or fails with a non-retryable error, we should ignore all errors on `abortTransaction`
return;
}

session.transaction.transition(TxnState.TRANSACTION_COMMITTED);
Expand All @@ -764,20 +739,36 @@ function endTransaction(
}
}

callback(error);
if (error != null) {
throw error;
}
}

if (session.transaction.recoveryToken) {
command.recoveryToken = session.transaction.recoveryToken;
}

const handleFirstCommandAttempt = (error?: Error) => {
try {
// send the command
await executeOperation(
session.client,
new RunAdminCommandOperation(command, {
session,
readPreference: ReadPreference.primary,
bypassPinningCheck: true
})
);
if (command.abortTransaction) {
// always unpin on abort regardless of command outcome
session.unpin();
}

if (error instanceof MongoError && isRetryableWriteError(error)) {
commandHandler();
} catch (firstAttemptErr) {
if (command.abortTransaction) {
// always unpin on abort regardless of command outcome
session.unpin();
}
if (firstAttemptErr instanceof MongoError && isRetryableWriteError(firstAttemptErr)) {
// SPEC-1185: apply majority write concern when retrying commitTransaction
if (command.commitTransaction) {
// per txns spec, must unpin session in this case
Expand All @@ -788,29 +779,23 @@ function endTransaction(
});
}

executeOperation(
session.client,
new RunAdminCommandOperation(command, {
session,
readPreference: ReadPreference.primary,
bypassPinningCheck: true
})
).then(() => commandHandler(), commandHandler);
return;
try {
await executeOperation(
session.client,
new RunAdminCommandOperation(command, {
session,
readPreference: ReadPreference.primary,
bypassPinningCheck: true
})
);
commandHandler();
} catch (secondAttemptErr) {
commandHandler(secondAttemptErr);
}
} else {
commandHandler(firstAttemptErr);
}

commandHandler(error);
};

// send the command
executeOperation(
session.client,
new RunAdminCommandOperation(command, {
session,
readPreference: ReadPreference.primary,
bypassPinningCheck: true
})
).then(() => handleFirstCommandAttempt(), handleFirstCommandAttempt);
}
}

/** @public */
Expand Down