Skip to content

[Auth] Add onAbort and refactor middleware #6178

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion common/api-review/auth.api.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ export function applyActionCode(auth: Auth, oobCode: string): Promise<void>;
// @public
export interface Auth {
readonly app: FirebaseApp;
beforeAuthStateChanged(callback: (user: User | null) => void | Promise<void>): Unsubscribe;
beforeAuthStateChanged(callback: (user: User | null) => void | Promise<void>, onAbort?: () => void): Unsubscribe;
readonly config: Config;
readonly currentUser: User | null;
readonly emulatorConfig: EmulatorConfig | null;
Expand Down Expand Up @@ -242,6 +242,9 @@ export interface AuthSettings {
appVerificationDisabledForTesting: boolean;
}

// @public
export function beforeAuthStateChanged(auth: Auth, callback: (user: User | null) => void | Promise<void>, onAbort?: () => void): Unsubscribe;

// @public
export const browserLocalPersistence: Persistence;

Expand Down
48 changes: 8 additions & 40 deletions packages/auth/src/core/auth/auth_impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ import { _getInstance } from '../util/instantiator';
import { _getUserLanguage } from '../util/navigator';
import { _getClientVersion } from '../util/version';
import { HttpHeader } from '../../api';
import { AuthMiddlewareQueue } from './middleware';

interface AsyncAction {
(): Promise<void>;
Expand All @@ -79,7 +80,7 @@ export class AuthImpl implements AuthInternal, _FirebaseService {
private redirectPersistenceManager?: PersistenceUserManager;
private authStateSubscription = new Subscription<User>(this);
private idTokenSubscription = new Subscription<User>(this);
private beforeStateQueue: Array<(user: User | null) => Promise<void>> = [];
private readonly beforeStateQueue = new AuthMiddlewareQueue(this);
private redirectUser: UserInternal | null = null;
private isProactiveRefreshEnabled = false;

Expand Down Expand Up @@ -225,7 +226,7 @@ export class AuthImpl implements AuthInternal, _FirebaseService {
// First though, ensure that we check the middleware is happy.
if (needsTocheckMiddleware) {
try {
await this._runBeforeStateCallbacks(futureCurrentUser);
await this.beforeStateQueue.runMiddleware(futureCurrentUser);
} catch(e) {
futureCurrentUser = previouslyStoredUser;
// We know this is available since the bit is only set when the
Expand Down Expand Up @@ -347,7 +348,7 @@ export class AuthImpl implements AuthInternal, _FirebaseService {
}

if (!skipBeforeStateCallbacks) {
await this._runBeforeStateCallbacks(user);
await this.beforeStateQueue.runMiddleware(user);
}

return this.queue(async () => {
Expand All @@ -356,23 +357,9 @@ export class AuthImpl implements AuthInternal, _FirebaseService {
});
}

async _runBeforeStateCallbacks(user: User | null): Promise<void> {
if (this.currentUser === user) {
return;
}
try {
for (const beforeStateCallback of this.beforeStateQueue) {
await beforeStateCallback(user);
}
} catch (e) {
throw this._errorFactory.create(
AuthErrorCode.LOGIN_BLOCKED, { originalMessage: e.message });
}
}

async signOut(): Promise<void> {
// Run first, to block _setRedirectUser() if any callbacks fail.
await this._runBeforeStateCallbacks(null);
await this.beforeStateQueue.runMiddleware(null);
// Clear the redirect user when signOut is called
if (this.redirectPersistenceManager || this._popupRedirectResolver) {
await this._setRedirectUser(null);
Expand Down Expand Up @@ -415,29 +402,10 @@ export class AuthImpl implements AuthInternal, _FirebaseService {
}

beforeAuthStateChanged(
callback: (user: User | null) => void | Promise<void>
callback: (user: User | null) => void | Promise<void>,
onAbort?: () => void,
): Unsubscribe {
// The callback could be sync or async. Wrap it into a
// function that is always async.
const wrappedCallback =
(user: User | null): Promise<void> => new Promise((resolve, reject) => {
try {
const result = callback(user);
// Either resolve with existing promise or wrap a non-promise
// return value into a promise.
resolve(result);
} catch (e) {
// Sync callback throws.
reject(e);
}
});
this.beforeStateQueue.push(wrappedCallback);
const index = this.beforeStateQueue.length - 1;
return () => {
// Unsubscribe. Replace with no-op. Do not remove from array, or it will disturb
// indexing of other elements.
this.beforeStateQueue[index] = () => Promise.resolve();
};
return this.beforeStateQueue.pushCallback(callback, onAbort);
}

onIdTokenChanged(
Expand Down
132 changes: 132 additions & 0 deletions packages/auth/src/core/auth/middleware.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import { expect, use } from 'chai';
import chaiAsPromised from 'chai-as-promised';
import * as sinon from 'sinon';
import sinonChai from 'sinon-chai';
import { testAuth, testUser } from '../../../test/helpers/mock_auth';
import { AuthInternal } from '../../model/auth';
import { User } from '../../model/public_types';
import { AuthMiddlewareQueue } from './middleware';

use(chaiAsPromised);
use(sinonChai);

describe('Auth middleware', () => {
let middlewareQueue: AuthMiddlewareQueue;
let user: User;
let auth: AuthInternal;

beforeEach(async () => {
auth = await testAuth();
user = testUser(auth, 'uid');
middlewareQueue = new AuthMiddlewareQueue(auth);
});

afterEach(() => {
sinon.restore();
});

it('calls middleware in order', async () => {
const calls: number[] = [];

middlewareQueue.pushCallback(() => {calls.push(1);});
middlewareQueue.pushCallback(() => {calls.push(2);});
middlewareQueue.pushCallback(() => {calls.push(3);});

await middlewareQueue.runMiddleware(user);

expect(calls).to.eql([1, 2, 3]);
});

it('rejects on error', async () => {
middlewareQueue.pushCallback(() => {
throw new Error('no');
});
await expect(middlewareQueue.runMiddleware(user)).to.be.rejectedWith('auth/login-blocked');
});

it('rejects on promise rejection', async () => {
middlewareQueue.pushCallback(() => Promise.reject('no'));
await expect(middlewareQueue.runMiddleware(user)).to.be.rejectedWith('auth/login-blocked');
});

it('awaits middleware completion before calling next', async () => {
const firstCb = sinon.spy();
const secondCb = sinon.spy();

middlewareQueue.pushCallback(() => {
// Force the first one to run one tick later
return new Promise(resolve => {
setTimeout(() => {
firstCb();
resolve();
}, 1);
});
});
middlewareQueue.pushCallback(secondCb);

await middlewareQueue.runMiddleware(user);
expect(secondCb).to.have.been.calledAfter(firstCb);
});

it('subsequent middleware not run after rejection', async () => {
const spy = sinon.spy();

middlewareQueue.pushCallback(() => {
throw new Error('no');
});
middlewareQueue.pushCallback(spy);

await expect(middlewareQueue.runMiddleware(user)).to.be.rejectedWith('auth/login-blocked');
expect(spy).not.to.have.been.called;
});

it('calls onAbort if provided but only for earlier runs', async () => {
const firstOnAbort = sinon.spy();
const secondOnAbort = sinon.spy();

middlewareQueue.pushCallback(() => {}, firstOnAbort);
middlewareQueue.pushCallback(() => {
throw new Error('no');
}, secondOnAbort);

await expect(middlewareQueue.runMiddleware(user)).to.be.rejectedWith('auth/login-blocked');
expect(firstOnAbort).to.have.been.called;
expect(secondOnAbort).not.to.have.been.called;
});

it('calls onAbort in reverse order', async () => {
const calls: number[] = [];

middlewareQueue.pushCallback(() => {}, () => {calls.push(1);});
middlewareQueue.pushCallback(() => {}, () => {calls.push(2);});
middlewareQueue.pushCallback(() => {}, () => {calls.push(3);});
middlewareQueue.pushCallback(() => {
throw new Error('no');
});

await expect(middlewareQueue.runMiddleware(user)).to.be.rejectedWith('auth/login-blocked');
expect(calls).to.eql([3, 2, 1]);
});

it('does not call any middleware if user matches null', async () => {
const spy = sinon.spy();

middlewareQueue.pushCallback(spy);
await middlewareQueue.runMiddleware(null);

expect(spy).not.to.have.been.called;
});

it('does not call any middleware if user matches object', async () => {
const spy = sinon.spy();

// Directly set it manually since the public function creates a
// copy of the user.
auth.currentUser = user;

middlewareQueue.pushCallback(spy);
await middlewareQueue.runMiddleware(user);

expect(spy).not.to.have.been.called;
});
});
76 changes: 76 additions & 0 deletions packages/auth/src/core/auth/middleware.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import { AuthInternal } from '../../model/auth';
import { Unsubscribe, User } from '../../model/public_types';
import { AuthErrorCode } from '../errors';

interface MiddlewareEntry {
(user: User | null): Promise<void>;
onAbort?: () => void;
}

export class AuthMiddlewareQueue {
private readonly queue: MiddlewareEntry[] = [];

constructor(private readonly auth: AuthInternal) {}

pushCallback(
callback: (user: User | null) => void | Promise<void>,
onAbort?: () => void): Unsubscribe {
// The callback could be sync or async. Wrap it into a
// function that is always async.
const wrappedCallback: MiddlewareEntry =
(user: User | null): Promise<void> => new Promise((resolve, reject) => {
try {
const result = callback(user);
// Either resolve with existing promise or wrap a non-promise
// return value into a promise.
resolve(result);
} catch (e) {
// Sync callback throws.
reject(e);
}
});
// Attach the onAbort if present
wrappedCallback.onAbort = onAbort;
this.queue.push(wrappedCallback);

const index = this.queue.length - 1;
return () => {
// Unsubscribe. Replace with no-op. Do not remove from array, or it will disturb
// indexing of other elements.
this.queue[index] = () => Promise.resolve();
};
}

async runMiddleware(nextUser: User | null): Promise<void> {
if (this.auth.currentUser === nextUser) {
return;
}

// While running the middleware, build a temporary stack of onAbort
// callbacks to call if one middleware callback rejects.

const onAbortStack: Array<() => void> = [];
try {
for (const beforeStateCallback of this.queue) {
await beforeStateCallback(nextUser);

// Only push the onAbort if the callback succeeds
if (beforeStateCallback.onAbort) {
onAbortStack.push(beforeStateCallback.onAbort);
}
}
} catch (e) {
// Run all onAbort, with separate try/catch to ignore any errors and
// continue
onAbortStack.reverse();
for (const onAbort of onAbortStack) {
try {
onAbort();
} catch (_) { /* swallow error */}
}

throw this.auth._errorFactory.create(
AuthErrorCode.LOGIN_BLOCKED, { originalMessage: e.message });
}
}
}
20 changes: 20 additions & 0 deletions packages/auth/src/core/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,26 @@ export function onIdTokenChanged(
completed
);
}
/**
* Adds a blocking callback that runs before an auth state change
* sets a new user.
*
* @param auth - The {@link Auth} instance.
* @param callback - callback triggered before new user value is set.
* If this throws, it blocks the user from being set.
* @param onAbort - callback triggered if a later `beforeAuthStateChanged()`
* callback throws, allowing you to undo any side effects.
*/
export function beforeAuthStateChanged(
auth: Auth,
callback: (user: User|null) => void | Promise<void>,
onAbort?: () => void,
): Unsubscribe {
return getModularInstance(auth).beforeAuthStateChanged(
callback,
onAbort
);
}
/**
* Adds an observer for changes to the user's sign-in state.
*
Expand Down
5 changes: 4 additions & 1 deletion packages/auth/src/model/public_types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,12 @@ export interface Auth {
*
* @param callback - callback triggered before new user value is set.
* If this throws, it blocks the user from being set.
* @param onAbort - callback triggered if a later `beforeAuthStateChanged()`
* callback throws, allowing you to undo any side effects.
*/
beforeAuthStateChanged(
callback: (user: User | null) => void | Promise<void>
callback: (user: User | null) => void | Promise<void>,
onAbort?: () => void,
): Unsubscribe;
/**
* Adds an observer for changes to the signed-in user's ID token.
Expand Down
17 changes: 15 additions & 2 deletions packages/auth/test/integration/flows/middleware_test_generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ export function generateMiddlewareTests(authGetter: () => Auth, signIn: () => Pr
* automatically unsubscribe after every test (since some tests may
* perform cleanup after that would be affected by the middleware)
*/
function beforeAuthStateChanged(callback: (user: User | null) => void | Promise<void>): void {
unsubscribes.push(auth.beforeAuthStateChanged(callback));
function beforeAuthStateChanged(callback: (user: User | null) => void | Promise<void>, onAbort?: () => void): void {
unsubscribes.push(auth.beforeAuthStateChanged(callback, onAbort));
}

it('can prevent user sign in', async () => {
Expand Down Expand Up @@ -192,5 +192,18 @@ export function generateMiddlewareTests(authGetter: () => Auth, signIn: () => Pr
await expect(auth.signOut()).to.be.rejectedWith('auth/login-blocked');
expect(auth.currentUser).to.eq(user);
});

it('calls onAbort after rejection', async () => {
const onAbort = sinon.spy();
beforeAuthStateChanged(() => {
// Pass
}, onAbort);
beforeAuthStateChanged(() => {
throw new Error('block sign out');
});

await expect(signIn()).to.be.rejectedWith('auth/login-blocked');
expect(onAbort).to.have.been.called;
});
});
}