Skip to content

Commit 2901875

Browse files
committed
Try with any pte
1 parent c185524 commit 2901875

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

extension/android/benchmark/android-llm-device-farm-test-spec.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,12 @@ phases:
8080
--es "model_dir" "/data/local/tmp/llama" \
8181
--es "tokenizer_path" "/data/local/tmp/llama/tokenizer.bin"
8282
83+
- echo "Run generic benchmark"
84+
- |
85+
adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.BenchmarkActivity \
86+
--es "model_dir" "/data/local/tmp/llama"
87+
88+
8389
post_test:
8490
commands:
8591
- echo "Gather LLM benchmark results"

extension/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,30 @@
1111
import android.app.Activity;
1212
import android.content.Intent;
1313
import android.os.Bundle;
14+
import java.io.File;
1415
import java.io.FileWriter;
1516
import java.io.IOException;
17+
import java.util.Arrays;
1618
import org.pytorch.executorch.Module;
1719

1820
public class BenchmarkActivity extends Activity {
1921
@Override
2022
protected void onCreate(Bundle savedInstanceState) {
2123
super.onCreate(savedInstanceState);
2224
Intent intent = getIntent();
23-
String modelPath = intent.getStringExtra("model_path");
25+
File modelDir = new File(intent.getStringExtra("model_dir"));
26+
File model =
27+
Arrays.stream(modelDir.listFiles())
28+
.filter(file -> file.getName().endsWith(".pte"))
29+
.findFirst()
30+
.get();
31+
2432
int numIter = intent.getIntExtra("num_iter", 10);
2533

2634
// TODO: Format the string with a parsable format
2735
StringBuilder resultText = new StringBuilder();
2836

29-
Module module = Module.load(modelPath);
37+
Module module = Module.load(model.getPath());
3038
for (int i = 0; i < numIter; i++) {
3139
long start = System.currentTimeMillis();
3240
module.forward();

0 commit comments

Comments
 (0)