Skip to content

Commit 04b8334

Browse files
committed
Test model download logic
1 parent a90440e commit 04b8334

File tree

2 files changed

+84
-7
lines changed

2 files changed

+84
-7
lines changed

packages/vertexai/src/methods/chrome-adapter.test.ts

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import sinonChai from 'sinon-chai';
2020
import chaiAsPromised from 'chai-as-promised';
2121
import { ChromeAdapter } from './chrome-adapter';
2222
import { Availability, LanguageModel } from '../types/language-model';
23+
import { stub } from 'sinon';
2324

2425
use(sinonChai);
2526
use(chaiAsPromised);
@@ -56,5 +57,81 @@ describe('ChromeAdapter', () => {
5657
})
5758
).to.be.true;
5859
});
60+
it('returns false and triggers download when model is available after download', async () => {
61+
const languageModelProvider = {
62+
availability: () => Promise.resolve(Availability.downloadable),
63+
create: () => Promise.resolve({})
64+
} as LanguageModel;
65+
const createStub = stub(languageModelProvider, 'create').resolves(
66+
{} as LanguageModel
67+
);
68+
const adapter = new ChromeAdapter(
69+
languageModelProvider,
70+
'prefer_on_device'
71+
);
72+
expect(
73+
await adapter.isAvailable({
74+
contents: [{ role: 'user', parts: [{ text: 'hi' }] }]
75+
})
76+
).to.be.false;
77+
expect(createStub).to.have.been.calledOnce;
78+
});
79+
it('avoids redundant downloads', async () => {
80+
const languageModelProvider = {
81+
availability: () => Promise.resolve(Availability.downloadable),
82+
create: () => Promise.resolve({})
83+
} as LanguageModel;
84+
const downloadPromise = new Promise<LanguageModel>(() => {
85+
/* never resolves */
86+
});
87+
const createStub = stub(languageModelProvider, 'create').returns(
88+
downloadPromise
89+
);
90+
const adapter = new ChromeAdapter(languageModelProvider);
91+
await adapter.isAvailable({
92+
contents: [{ role: 'user', parts: [{ text: 'hi' }] }]
93+
});
94+
await adapter.isAvailable({
95+
contents: [{ role: 'user', parts: [{ text: 'hi' }] }]
96+
});
97+
expect(createStub).to.have.been.calledOnce;
98+
});
99+
it('clears state when download completes', async () => {
100+
const languageModelProvider = {
101+
availability: () => Promise.resolve(Availability.downloadable),
102+
create: () => Promise.resolve({})
103+
} as LanguageModel;
104+
let resolveDownload;
105+
const downloadPromise = new Promise<LanguageModel>(resolveCallback => {
106+
resolveDownload = resolveCallback;
107+
});
108+
const createStub = stub(languageModelProvider, 'create').returns(
109+
downloadPromise
110+
);
111+
const adapter = new ChromeAdapter(languageModelProvider);
112+
await adapter.isAvailable({
113+
contents: [{ role: 'user', parts: [{ text: 'hi' }] }]
114+
});
115+
resolveDownload!();
116+
await adapter.isAvailable({
117+
contents: [{ role: 'user', parts: [{ text: 'hi' }] }]
118+
});
119+
expect(createStub).to.have.been.calledTwice;
120+
});
121+
it('returns false when model is never available', async () => {
122+
const languageModelProvider = {
123+
availability: () => Promise.resolve(Availability.unavailable),
124+
create: () => Promise.resolve({})
125+
} as LanguageModel;
126+
const adapter = new ChromeAdapter(
127+
languageModelProvider,
128+
'prefer_on_device'
129+
);
130+
expect(
131+
await adapter.isAvailable({
132+
contents: [{ role: 'user', parts: [{ text: 'hi' }] }]
133+
})
134+
).to.be.false;
135+
});
59136
});
60137
});

packages/vertexai/src/methods/chrome-adapter.ts

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,9 @@ import { isChrome } from '@firebase/util';
3737
* and encapsulates logic for detecting when on-device is possible.
3838
*/
3939
export class ChromeAdapter {
40-
downloadPromise: Promise<LanguageModel> | undefined;
41-
oldSession: LanguageModel | undefined;
40+
private isDownloading = false;
41+
private downloadPromise: Promise<LanguageModel | void> | undefined;
42+
private oldSession: LanguageModel | undefined;
4243
constructor(
4344
private languageModelProvider?: LanguageModel,
4445
private mode?: InferenceMode,
@@ -152,16 +153,15 @@ export class ChromeAdapter {
152153
return true;
153154
}
154155
private download(): void {
155-
if (this.downloadPromise) {
156+
if (this.isDownloading) {
156157
return;
157158
}
159+
this.isDownloading = true;
158160
this.downloadPromise = this.languageModelProvider
159161
?.create(this.onDeviceParams)
160-
.then((model: LanguageModel) => {
161-
delete this.downloadPromise;
162-
return model;
162+
.then(() => {
163+
this.isDownloading = false;
163164
});
164-
return;
165165
}
166166
private static toSystemPrompt(
167167
prompt: string | Content | Part | undefined

0 commit comments

Comments
 (0)