Skip to content

Commit 4108754

Browse files
PoC commit
1 parent 918fe69 commit 4108754

File tree

1 file changed

+90
-113
lines changed

1 file changed

+90
-113
lines changed

src/sessions.ts

Lines changed: 90 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import { promisify } from 'util';
2-
31
import { Binary, type Document, Long, type Timestamp } from './bson';
42
import type { CommandOptions, Connection } from './cmap/connection';
53
import { ConnectionPoolMetrics } from './cmap/metrics';
@@ -38,7 +36,6 @@ import {
3836
import {
3937
ByteUtils,
4038
calculateDurationInMs,
41-
type Callback,
4239
commandSupportsReadConcern,
4340
isPromiseLike,
4441
List,
@@ -415,14 +412,14 @@ export class ClientSession extends TypedEventEmitter<ClientSessionEvents> {
415412
* Commits the currently active transaction in this session.
416413
*/
417414
async commitTransaction(): Promise<void> {
418-
return endTransactionAsync(this, 'commitTransaction');
415+
return endTransaction(this, 'commitTransaction');
419416
}
420417

421418
/**
422419
* Aborts the currently active transaction in this session.
423420
*/
424421
async abortTransaction(): Promise<void> {
425-
return endTransactionAsync(this, 'abortTransaction');
422+
return endTransaction(this, 'abortTransaction');
426423
}
427424

428425
/**
@@ -545,33 +542,33 @@ function isMaxTimeMSExpiredError(err: MongoError) {
545542
);
546543
}
547544

548-
function attemptTransactionCommit<T>(
545+
async function attemptTransactionCommit<T>(
549546
session: ClientSession,
550547
startTime: number,
551548
fn: WithTransactionCallback<T>,
552549
result: any,
553550
options: TransactionOptions
554551
): Promise<T> {
555-
return session.commitTransaction().then(
556-
() => result,
557-
(err: MongoError) => {
558-
if (
559-
err instanceof MongoError &&
560-
hasNotTimedOut(startTime, MAX_WITH_TRANSACTION_TIMEOUT) &&
561-
!isMaxTimeMSExpiredError(err)
562-
) {
563-
if (err.hasErrorLabel(MongoErrorLabel.UnknownTransactionCommitResult)) {
564-
return attemptTransactionCommit(session, startTime, fn, result, options);
565-
}
566-
567-
if (err.hasErrorLabel(MongoErrorLabel.TransientTransactionError)) {
568-
return attemptTransaction(session, startTime, fn, options);
569-
}
552+
try {
553+
await session.commitTransaction();
554+
return result;
555+
} catch (err) {
556+
if (
557+
err instanceof MongoError &&
558+
hasNotTimedOut(startTime, MAX_WITH_TRANSACTION_TIMEOUT) &&
559+
!isMaxTimeMSExpiredError(err)
560+
) {
561+
if (err.hasErrorLabel(MongoErrorLabel.UnknownTransactionCommitResult)) {
562+
return attemptTransactionCommit(session, startTime, fn, result, options);
570563
}
571564

572-
throw err;
565+
if (err.hasErrorLabel(MongoErrorLabel.TransientTransactionError)) {
566+
return attemptTransaction(session, startTime, fn, options);
567+
}
573568
}
574-
);
569+
570+
throw err;
571+
}
575572
}
576573

577574
const USER_EXPLICIT_TXN_END_STATES = new Set<TxnState>([
@@ -584,81 +581,64 @@ function userExplicitlyEndedTransaction(session: ClientSession) {
584581
return USER_EXPLICIT_TXN_END_STATES.has(session.transaction.state);
585582
}
586583

587-
function attemptTransaction<T>(
584+
async function attemptTransaction<T>(
588585
session: ClientSession,
589586
startTime: number,
590587
fn: WithTransactionCallback<T>,
591588
options: TransactionOptions = {}
592589
): Promise<any> {
593590
session.startTransaction(options);
591+
async function maybeRetryOrThrow(err: MongoError): Promise<any> {
592+
if (
593+
err instanceof MongoError &&
594+
err.hasErrorLabel(MongoErrorLabel.TransientTransactionError) &&
595+
hasNotTimedOut(startTime, MAX_WITH_TRANSACTION_TIMEOUT)
596+
) {
597+
const transactionResult = await attemptTransaction(session, startTime, fn, options);
598+
return transactionResult;
599+
}
594600

595-
let promise;
596-
try {
597-
promise = fn(session);
598-
} catch (err) {
599-
promise = Promise.reject(err);
600-
}
601+
if (isMaxTimeMSExpiredError(err)) {
602+
err.addErrorLabel(MongoErrorLabel.UnknownTransactionCommitResult);
603+
}
601604

602-
if (!isPromiseLike(promise)) {
603-
session.abortTransaction().catch(() => null);
604-
return Promise.reject(
605-
new MongoInvalidArgumentError('Function provided to `withTransaction` must return a Promise')
606-
);
605+
throw err;
607606
}
608607

609-
return promise.then(
610-
result => {
611-
if (userExplicitlyEndedTransaction(session)) {
612-
return result;
613-
}
614-
615-
return attemptTransactionCommit(session, startTime, fn, result, options);
616-
},
617-
err => {
618-
function maybeRetryOrThrow(err: MongoError): Promise<any> {
619-
if (
620-
err instanceof MongoError &&
621-
err.hasErrorLabel(MongoErrorLabel.TransientTransactionError) &&
622-
hasNotTimedOut(startTime, MAX_WITH_TRANSACTION_TIMEOUT)
623-
) {
624-
return attemptTransaction(session, startTime, fn, options);
625-
}
626-
627-
if (isMaxTimeMSExpiredError(err)) {
628-
err.addErrorLabel(MongoErrorLabel.UnknownTransactionCommitResult);
629-
}
608+
try {
609+
const promise = fn(session);
630610

631-
throw err;
632-
}
611+
if (!isPromiseLike(promise)) {
612+
await session.abortTransaction();
613+
return new MongoInvalidArgumentError(
614+
'Function provided to `withTransaction` must return a Promise'
615+
);
616+
}
617+
const result = await promise;
633618

634-
if (session.inTransaction()) {
635-
return session.abortTransaction().then(() => maybeRetryOrThrow(err));
636-
}
619+
if (userExplicitlyEndedTransaction(session)) {
620+
return result;
621+
}
637622

638-
return maybeRetryOrThrow(err);
623+
return await attemptTransactionCommit(session, startTime, fn, result, options);
624+
} catch (err) {
625+
if (session.inTransaction()) {
626+
await session.abortTransaction();
639627
}
640-
);
641-
}
642628

643-
const endTransactionAsync = promisify(
644-
endTransaction as (
645-
session: ClientSession,
646-
commandName: 'abortTransaction' | 'commitTransaction',
647-
callback: (error: Error) => void
648-
) => void
649-
);
629+
return await maybeRetryOrThrow(err);
630+
}
631+
}
650632

651-
function endTransaction(
633+
async function endTransaction(
652634
session: ClientSession,
653-
commandName: 'abortTransaction' | 'commitTransaction',
654-
callback: Callback<void>
635+
commandName: 'abortTransaction' | 'commitTransaction'
655636
) {
656637
// handle any initial problematic cases
657638
const txnState = session.transaction.state;
658639

659640
if (txnState === TxnState.NO_TRANSACTION) {
660-
callback(new MongoTransactionError('No transaction started'));
661-
return;
641+
throw new MongoTransactionError('No transaction started');
662642
}
663643

664644
if (commandName === 'commitTransaction') {
@@ -668,37 +648,31 @@ function endTransaction(
668648
) {
669649
// the transaction was never started, we can safely exit here
670650
session.transaction.transition(TxnState.TRANSACTION_COMMITTED_EMPTY);
671-
callback();
672651
return;
673652
}
674653

675654
if (txnState === TxnState.TRANSACTION_ABORTED) {
676-
callback(
677-
new MongoTransactionError('Cannot call commitTransaction after calling abortTransaction')
655+
throw new MongoTransactionError(
656+
'Cannot call commitTransaction after calling abortTransaction'
678657
);
679-
return;
680658
}
681659
} else {
682660
if (txnState === TxnState.STARTING_TRANSACTION) {
683661
// the transaction was never started, we can safely exit here
684662
session.transaction.transition(TxnState.TRANSACTION_ABORTED);
685-
callback();
686-
return;
687663
}
688664

689665
if (txnState === TxnState.TRANSACTION_ABORTED) {
690-
callback(new MongoTransactionError('Cannot call abortTransaction twice'));
691-
return;
666+
throw new MongoTransactionError('Cannot call abortTransaction twice');
692667
}
693668

694669
if (
695670
txnState === TxnState.TRANSACTION_COMMITTED ||
696671
txnState === TxnState.TRANSACTION_COMMITTED_EMPTY
697672
) {
698-
callback(
699-
new MongoTransactionError('Cannot call abortTransaction after calling commitTransaction')
673+
throw new MongoTransactionError(
674+
'Cannot call abortTransaction after calling commitTransaction'
700675
);
701-
return;
702676
}
703677
}
704678

@@ -731,9 +705,6 @@ function endTransaction(
731705
if (session.loadBalanced) {
732706
maybeClearPinnedConnection(session, { force: false });
733707
}
734-
735-
// The spec indicates that we should ignore all errors on `abortTransaction`
736-
return callback();
737708
}
738709

739710
session.transaction.transition(TxnState.TRANSACTION_COMMITTED);
@@ -753,15 +724,13 @@ function endTransaction(
753724
session.unpin({ error });
754725
}
755726
}
756-
757-
callback(error);
758727
}
759728

760729
if (session.transaction.recoveryToken) {
761730
command.recoveryToken = session.transaction.recoveryToken;
762731
}
763732

764-
const handleFirstCommandAttempt = (error?: Error) => {
733+
const handleFirstCommandAttempt = async (error?: Error) => {
765734
if (command.abortTransaction) {
766735
// always unpin on abort regardless of command outcome
767736
session.unpin();
@@ -778,29 +747,37 @@ function endTransaction(
778747
});
779748
}
780749

781-
executeOperation(
782-
session.client,
783-
new RunAdminCommandOperation(command, {
784-
session,
785-
readPreference: ReadPreference.primary,
786-
bypassPinningCheck: true
787-
})
788-
).then(() => commandHandler(), commandHandler);
789-
return;
750+
try {
751+
await executeOperation(
752+
session.client,
753+
new RunAdminCommandOperation(command, {
754+
session,
755+
readPreference: ReadPreference.primary,
756+
bypassPinningCheck: true
757+
})
758+
);
759+
commandHandler();
760+
} catch (err) {
761+
commandHandler(err);
762+
throw err;
763+
}
790764
}
791-
792-
commandHandler(error);
793765
};
794766

795-
// send the command
796-
executeOperation(
797-
session.client,
798-
new RunAdminCommandOperation(command, {
799-
session,
800-
readPreference: ReadPreference.primary,
801-
bypassPinningCheck: true
802-
})
803-
).then(() => handleFirstCommandAttempt(), handleFirstCommandAttempt);
767+
try {
768+
// send the command
769+
await executeOperation(
770+
session.client,
771+
new RunAdminCommandOperation(command, {
772+
session,
773+
readPreference: ReadPreference.primary,
774+
bypassPinningCheck: true
775+
})
776+
);
777+
await handleFirstCommandAttempt();
778+
} catch (err) {
779+
await handleFirstCommandAttempt(err);
780+
}
804781
}
805782

806783
/** @public */

0 commit comments

Comments
 (0)