Skip to content

chore(property-provider): refactor memoize to use arrow functions #1281

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/credential-provider-node/src/index.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ describe("defaultProvider", () => {

expect(await provider()).toEqual(creds);

expect(provider()).toBe(provider());
expect(provider()).toStrictEqual(provider());

expect(await provider()).toEqual(creds);
expect((fromEnv() as any).mock.calls.length).toBe(1);
Expand Down
153 changes: 100 additions & 53 deletions packages/property-provider/src/memoize.spec.ts
Original file line number Diff line number Diff line change
@@ -1,79 +1,126 @@
import { memoize } from "./memoize";
import { Provider } from "@aws-sdk/types";

describe("memoize", () => {
let provider: jest.Mock;
const mockReturn = "foo";
const repeatTimes = 10;

beforeEach(() => {
provider = jest.fn().mockResolvedValue(mockReturn);
});

afterEach(() => {
jest.clearAllMocks();
});

describe("static memoization", () => {
it("should cache the resolved provider", async () => {
const provider = jest.fn().mockResolvedValue("foo");
expect.assertions(repeatTimes * 2);

const memoized = memoize(provider);

expect(await memoized()).toEqual("foo");
expect(provider.mock.calls.length).toBe(1);
expect(await memoized()).toEqual("foo");
expect(provider.mock.calls.length).toBe(1);
for (const index in [...Array(repeatTimes).keys()]) {
expect(await memoized()).toStrictEqual(mockReturn);
expect(provider).toHaveBeenCalledTimes(1);
}
});

it("should always return the same promise", () => {
const provider = jest.fn().mockResolvedValue("foo");
expect.assertions(repeatTimes * 2);

const memoized = memoize(provider);
const result = memoized();

expect(memoized()).toBe(result);
for (const index in [...Array(repeatTimes).keys()]) {
expect(memoized()).toStrictEqual(result);
expect(provider).toHaveBeenCalledTimes(1);
}
});
});

describe("refreshing memoization", () => {
it("should not reinvoke the underlying provider while isExpired returns `false`", async () => {
const provider = jest.fn().mockResolvedValue("foo");
const isExpired = jest.fn().mockReturnValue(false);
const memoized = memoize(provider, isExpired);

const checkCount = 10;
for (let i = 0; i < checkCount; i++) {
expect(await memoized()).toBe("foo");
}
let isExpired: jest.Mock;
let requiresRefresh: jest.Mock;

expect(isExpired.mock.calls.length).toBe(checkCount);
expect(provider.mock.calls.length).toBe(1);
beforeEach(() => {
isExpired = jest.fn().mockReturnValue(true);
requiresRefresh = jest.fn().mockReturnValue(false);
});

it("should reinvoke the underlying provider when isExpired returns `true`", async () => {
const provider = jest.fn().mockResolvedValue("foo");
const isExpired = jest.fn().mockReturnValue(false);
const memoized = memoize(provider, isExpired);

const checkCount = 10;
for (let i = 0; i < checkCount; i++) {
expect(await memoized()).toBe("foo");
}

expect(isExpired.mock.calls.length).toBe(checkCount);
expect(provider.mock.calls.length).toBe(1);

isExpired.mockReturnValueOnce(true);
for (let i = 0; i < checkCount; i++) {
expect(await memoized()).toBe("foo");
}

expect(isExpired.mock.calls.length).toBe(checkCount * 2);
expect(provider.mock.calls.length).toBe(2);
describe("should not reinvoke the underlying provider while isExpired returns `false`", () => {
const isExpiredFalseTest = async (requiresRefresh?: any) => {
isExpired.mockReturnValue(false);
const memoized = memoize(provider, isExpired, requiresRefresh);

for (const index in [...Array(repeatTimes).keys()]) {
expect(await memoized()).toEqual(mockReturn);
}

expect(isExpired).toHaveBeenCalledTimes(repeatTimes);
if (requiresRefresh) {
expect(requiresRefresh).toHaveBeenCalledTimes(repeatTimes);
}
expect(provider).toHaveBeenCalledTimes(1);
};

it("when requiresRefresh is not passed", async () => {
return isExpiredFalseTest();
});

it("when requiresRefresh returns true", () => {
requiresRefresh.mockReturnValue(true);
return isExpiredFalseTest(requiresRefresh);
});
});

it("should return the same promise for invocations 2-infinity if `requiresRefresh` returns `false`", async () => {
const provider = jest.fn().mockResolvedValue("foo");
const isExpired = jest.fn().mockReturnValue(true);
const requiresRefresh = jest.fn().mockReturnValue(false);

const memoized = memoize(provider, isExpired, requiresRefresh);
expect(await memoized()).toBe("foo");
const set = new Set<Promise<string>>();

const checkCount = 10;
for (let i = 0; i < checkCount; i++) {
set.add(memoized());
}
describe("should reinvoke the underlying provider when isExpired returns `true`", () => {
const isExpiredTrueTest = async (requiresRefresh?: any) => {
const memoized = memoize(provider, isExpired, requiresRefresh);

for (const index in [...Array(repeatTimes).keys()]) {
expect(await memoized()).toEqual(mockReturn);
}

expect(isExpired).toHaveBeenCalledTimes(repeatTimes);
if (requiresRefresh) {
expect(requiresRefresh).toHaveBeenCalledTimes(repeatTimes);
}
expect(provider).toHaveBeenCalledTimes(repeatTimes + 1);
};

it("when requiresRefresh is not passed", () => {
return isExpiredTrueTest();
});

it("when requiresRefresh returns true", () => {
requiresRefresh.mockReturnValue(true);
return isExpiredTrueTest(requiresRefresh);
});
});

expect(set.size).toBe(1);
describe("should return the same promise for invocations 2-infinity if `requiresRefresh` returns `false`", () => {
const requiresRefreshFalseTest = async () => {
const memoized = memoize(provider, isExpired, requiresRefresh);
const result = memoized();
expect(await result).toBe(mockReturn);

for (const index in [...Array(repeatTimes).keys()]) {
expect(memoized()).toStrictEqual(result);
expect(provider).toHaveBeenCalledTimes(1);
}

expect(requiresRefresh).toHaveBeenCalledTimes(1);
expect(isExpired).not.toHaveBeenCalled();
};

it("when isExpired returns true", () => {
return requiresRefreshFalseTest();
});

it("when isExpired returns false", () => {
isExpired.mockReturnValue(false);
return requiresRefreshFalseTest();
});
});
});
});
101 changes: 50 additions & 51 deletions packages/property-provider/src/memoize.ts
Original file line number Diff line number Diff line change
@@ -1,48 +1,50 @@
import { Provider } from "@aws-sdk/types";

/**
*
* Decorates a provider function with either static memoization.
*
* To create a statically memoized provider, supply a provider as the only
* argument to this function. The provider will be invoked once, and all
* invocations of the provider returned by `memoize` will return the same
* promise object.
*
* @param provider The provider whose result should be cached indefinitely.
*/
export function memoize<T>(provider: Provider<T>): Provider<T>;
interface MemoizeOverload {
/**
*
* Decorates a provider function with either static memoization.
*
* To create a statically memoized provider, supply a provider as the only
* argument to this function. The provider will be invoked once, and all
* invocations of the provider returned by `memoize` will return the same
* promise object.
*
* @param provider The provider whose result should be cached indefinitely.
*/
<T>(provider: Provider<T>): Provider<T>;

/**
* Decorates a provider function with refreshing memoization.
*
* @param provider The provider whose result should be cached.
* @param isExpired A function that will evaluate the resolved value and
* determine if it is expired. For example, when
* memoizing AWS credential providers, this function
* should return `true` when the credential's
* expiration is in the past (or very near future) and
* `false` otherwise.
* @param requiresRefresh A function that will evaluate the resolved value and
* determine if it represents static value or one that
* will eventually need to be refreshed. For example,
* AWS credentials that have no defined expiration will
* never need to be refreshed, so this function would
* return `true` if the credentials resolved by the
* underlying provider had an expiration and `false`
* otherwise.
*/
export function memoize<T>(
provider: Provider<T>,
isExpired: (resolved: T) => boolean,
requiresRefresh?: (resolved: T) => boolean
): Provider<T>;
/**
* Decorates a provider function with refreshing memoization.
*
* @param provider The provider whose result should be cached.
* @param isExpired A function that will evaluate the resolved value and
* determine if it is expired. For example, when
* memoizing AWS credential providers, this function
* should return `true` when the credential's
* expiration is in the past (or very near future) and
* `false` otherwise.
* @param requiresRefresh A function that will evaluate the resolved value and
* determine if it represents static value or one that
* will eventually need to be refreshed. For example,
* AWS credentials that have no defined expiration will
* never need to be refreshed, so this function would
* return `true` if the credentials resolved by the
* underlying provider had an expiration and `false`
* otherwise.
*/
<T>(
provider: Provider<T>,
isExpired: (resolved: T) => boolean,
requiresRefresh?: (resolved: T) => boolean
): Provider<T>;
}

export function memoize<T>(
export const memoize: MemoizeOverload = <T>(
provider: Provider<T>,
isExpired?: (resolved: T) => boolean,
requiresRefresh?: (resolved: T) => boolean
): Provider<T> {
): Provider<T> => {
if (isExpired === undefined) {
// This is a static memoization; no need to incorporate refreshing
const result = provider();
Expand All @@ -52,22 +54,19 @@ export function memoize<T>(
let result = provider();
let isConstant: boolean = false;

return () => {
return async () => {
if (isConstant) {
return result;
}

return result.then(resolved => {
if (requiresRefresh && !requiresRefresh(resolved)) {
isConstant = true;
return resolved;
}

if (isExpired(resolved)) {
return (result = provider());
}

const resolved = await result;
if (requiresRefresh && !requiresRefresh(resolved)) {
isConstant = true;
return resolved;
});
}
if (isExpired(resolved)) {
return (result = provider());
}
return resolved;
};
}
};