Skip to content

Commit 440a62a

Browse files
committed
Add a mock perf test for Android llama2 tps
1 parent 599cfde commit 440a62a

File tree

1 file changed

+57
-0
lines changed
  • examples/demo-apps/android/LlamaDemo/app/src/androidTest/java/com/example/executorchllamademo

1 file changed

+57
-0
lines changed
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,57 @@
1+
package com.example.executorchllamademo;
2+
3+
import static junit.framework.TestCase.assertTrue;
4+
import static org.junit.Assert.assertEquals;
5+
import static org.junit.Assert.assertFalse;
6+
7+
import androidx.test.ext.junit.runners.AndroidJUnit4;
8+
9+
import org.junit.Test;
10+
import org.junit.runner.RunWith;
11+
import org.pytorch.executorch.LlamaCallback;
12+
import org.pytorch.executorch.LlamaModule;
13+
14+
import java.util.ArrayList;
15+
import java.util.List;
16+
17+
@RunWith(AndroidJUnit4.class)
18+
public class PerfTest implements LlamaCallback {
19+
20+
private static final String RESOURCE_PATH = "/data/local/tmp/llama/";
21+
private static final String MODEL_NAME = "xnnpack_llama2.pte";
22+
private static final String TOKENIZER_BIN = "tokenizer.bin";
23+
24+
// From https://github.com/pytorch/executorch/blob/main/examples/models/llama2/README.md
25+
private static final Float EXPECTED_TPS = 7.0F;
26+
27+
private final List<String> results = new ArrayList<>();
28+
private final List<Float> tokensPerSecond = new ArrayList<>();
29+
30+
@Test
31+
public void testTokensPerSecond() {
32+
String modelPath = RESOURCE_PATH + MODEL_NAME;
33+
String tokenizerPath = RESOURCE_PATH + TOKENIZER_BIN;
34+
LlamaModule mModule = new LlamaModule(modelPath, tokenizerPath, 0.8f);
35+
36+
int loadResult = mModule.load();
37+
// Check that the model can be load successfully
38+
assertEquals(0, loadResult);
39+
40+
// Run some testing prompt
41+
mModule.generate("How do you do! I'm testing llama2 on mobile device", PerfTest.this);
42+
assertFalse(tokensPerSecond.isEmpty());
43+
44+
final Float tps = tokensPerSecond.get(tokensPerSecond.size() - 1);
45+
assertTrue(tps >= EXPECTED_TPS);
46+
}
47+
48+
@Override
49+
public void onResult(String result) {
50+
results.add(result);
51+
}
52+
53+
@Override
54+
public void onStats(float tps) {
55+
tokensPerSecond.add(tps);
56+
}
57+
}

0 commit comments

Comments
 (0)