Skip to content

Commit 40720f0

Browse files
authored
Use Android llm benchmark runner
Differential Revision: D62279317 Pull Request resolved: #5094
1 parent 0458c2e commit 40720f0

File tree

3 files changed

+43
-3
lines changed

3 files changed

+43
-3
lines changed

.github/workflows/upload-android-test-specs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ jobs:
4141
with:
4242
# Just use a small model here with a minimal amount of configuration to test the spec
4343
models: stories110M
44-
devices: samsung_galaxy_s2x
44+
devices: samsung_galaxy_s22
4545
delegates: xnnpack
4646
test_spec: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifact/android-llm-device-farm-test-spec.yml
4747

examples/demo-apps/android/LlamaDemo/android-llm-device-farm-test-spec.yml

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,30 @@ phases:
7373
fi
7474
fi;
7575
76+
# Run the new generic benchmark activity https://developer.android.com/tools/adb#am
77+
- echo "Run LLM benchmark"
78+
- |
79+
adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n com.example.executorchllamademo/.LlmBenchmarkRunner \
80+
--es "model_dir" "/data/local/tmp/llama" \
81+
--es "tokenizer_path" "/data/local/tmp/llama/tokenizer.bin"
82+
7683
post_test:
7784
commands:
85+
- echo "Gather LLM benchmark results"
86+
- |
87+
BENCHMARK_RESULTS=""
88+
ATTEMPT=0
89+
MAX_ATTEMPT=10
90+
while [ -z "${BENCHMARK_RESULTS}" ] && [ $ATTEMPT -lt $MAX_ATTEMPT ]; do
91+
echo "Waiting for benchmark results..."
92+
BENCHMARK_RESULTS=$(adb -s $DEVICEFARM_DEVICE_UDID shell run-as com.example.executorchllamademo cat files/benchmark_results.json)
93+
sleep 30
94+
((ATTEMPT++))
95+
done
96+
97+
adb -s $DEVICEFARM_DEVICE_UDID shell run-as com.example.executorchllamademo ls -la files/
98+
# Trying to pull the file using adb ends up with permission error, but this works too, so why not
99+
echo "${BENCHMARK_RESULTS}" > $DEVICEFARM_LOG_DIR/benchmark_results.json
78100
79101
artifacts:
80102
# By default, Device Farm will collect your artifacts from the $DEVICEFARM_LOG_DIR directory.

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/LlmBenchmarkRunner.java

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,11 @@
1414
import android.util.Log;
1515
import android.widget.TextView;
1616
import androidx.annotation.NonNull;
17+
import com.google.gson.Gson;
18+
import java.io.File;
1719
import java.io.FileWriter;
1820
import java.io.IOException;
21+
import java.util.Arrays;
1922

2023
public class LlmBenchmarkRunner extends Activity implements ModelRunnerCallback {
2124
ModelRunner mModelRunner;
@@ -32,7 +35,12 @@ protected void onCreate(Bundle savedInstanceState) {
3235

3336
Intent intent = getIntent();
3437

35-
String modelPath = intent.getStringExtra("model_path");
38+
File modelDir = new File(intent.getStringExtra("model_dir"));
39+
File model =
40+
Arrays.stream(modelDir.listFiles())
41+
.filter(file -> file.getName().endsWith(".pte"))
42+
.findFirst()
43+
.get();
3644
String tokenizerPath = intent.getStringExtra("tokenizer_path");
3745

3846
float temperature = intent.getFloatExtra("temperature", 0.8f);
@@ -42,7 +50,7 @@ protected void onCreate(Bundle savedInstanceState) {
4250
}
4351

4452
mStatsDump = new StatsDump();
45-
mModelRunner = new ModelRunner(modelPath, tokenizerPath, temperature, this);
53+
mModelRunner = new ModelRunner(model.getPath(), tokenizerPath, temperature, this);
4654
mStatsDump.loadStart = System.currentTimeMillis();
4755
}
4856

@@ -79,11 +87,21 @@ public void onGenerationStopped() {
7987
mTextView.append(mStatsDump.toString());
8088
});
8189

90+
// TODO (huydhn): Remove txt files here once the JSON format is ready
8291
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.txt")) {
8392
writer.write(mStatsDump.toString());
8493
} catch (IOException e) {
8594
e.printStackTrace();
8695
}
96+
97+
// TODO (huydhn): Figure out on what the final JSON results looks like, we need something
98+
// with the same number of fields as https://github.com/pytorch/pytorch/pull/135042
99+
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
100+
Gson gson = new Gson();
101+
writer.write(gson.toJson(mStatsDump));
102+
} catch (IOException e) {
103+
e.printStackTrace();
104+
}
87105
}
88106
}
89107

0 commit comments

Comments
 (0)