Skip to content

Commit d761f99

Browse files
huydhn authored and facebook-github-bot committed
Add a mock perf test for llama2 on Android (#2963)
Summary: I'm trying to setup a simple perf test when running llama2 on Android. It's naively sent a prompt and record the TPS. Open for comment about the test here before setting this up on CI. ### Testing Copy the exported model and the tokenizer as usual, then cd to the app and run `./gradlew :app:connectAndroidTest`. The test will fail if the model is failed to load or if the TPS is lower than 7 as measure by https://github.com/pytorch/executorch/tree/main/examples/models/llama2 Pull Request resolved: #2963 Reviewed By: kirklandsign Differential Revision: D55951637 Pulled By: huydhn fbshipit-source-id: 34c189aefd7e31514fcf49103352ef3cf8e5b2c9
1 parent 2fc99b0 commit d761f99

File tree

1 file changed

+65
-0
lines changed
  • examples/demo-apps/android/LlamaDemo/app/src/androidTest/java/com/example/executorchllamademo

1 file changed

+65
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
package com.example.executorchllamademo;
10+
11+
import static junit.framework.TestCase.assertTrue;
12+
import static org.junit.Assert.assertEquals;
13+
import static org.junit.Assert.assertFalse;
14+
15+
import androidx.test.ext.junit.runners.AndroidJUnit4;
16+
import java.util.ArrayList;
17+
import java.util.List;
18+
import org.junit.Test;
19+
import org.junit.runner.RunWith;
20+
import org.pytorch.executorch.LlamaCallback;
21+
import org.pytorch.executorch.LlamaModule;
22+
23+
@RunWith(AndroidJUnit4.class)
24+
public class PerfTest implements LlamaCallback {
25+
26+
private static final String RESOURCE_PATH = "/data/local/tmp/llama/";
27+
private static final String MODEL_NAME = "xnnpack_llama2.pte";
28+
private static final String TOKENIZER_BIN = "tokenizer.bin";
29+
30+
// From https://github.com/pytorch/executorch/blob/main/examples/models/llama2/README.md
31+
private static final Float EXPECTED_TPS = 7.0F;
32+
33+
private final List<String> results = new ArrayList<>();
34+
private final List<Float> tokensPerSecond = new ArrayList<>();
35+
36+
@Test
37+
public void testTokensPerSecond() {
38+
String modelPath = RESOURCE_PATH + MODEL_NAME;
39+
String tokenizerPath = RESOURCE_PATH + TOKENIZER_BIN;
40+
LlamaModule mModule = new LlamaModule(modelPath, tokenizerPath, 0.8f);
41+
42+
int loadResult = mModule.load();
43+
// Check that the model can be load successfully
44+
assertEquals(0, loadResult);
45+
46+
// Run a testing prompt
47+
mModule.generate("How do you do! I'm testing llama2 on mobile device", PerfTest.this);
48+
assertFalse(tokensPerSecond.isEmpty());
49+
50+
final Float tps = tokensPerSecond.get(tokensPerSecond.size() - 1);
51+
assertTrue(
52+
"The observed TPS " + tps + " is less than the expected TPS " + EXPECTED_TPS,
53+
tps >= EXPECTED_TPS);
54+
}
55+
56+
@Override
57+
public void onResult(String result) {
58+
results.add(result);
59+
}
60+
61+
@Override
62+
public void onStats(float tps) {
63+
tokensPerSecond.add(tps);
64+
}
65+
}

0 commit comments

Comments
 (0)