Skip to content

Use Android llm benchmark runner #5094

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Sep 6, 2024
2 changes: 1 addition & 1 deletion .github/workflows/upload-android-test-specs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ jobs:
with:
# Just use a small model here with a minimal amount of configuration to test the spec
models: stories110M
devices: samsung_galaxy_s2x
devices: samsung_galaxy_s22
delegates: xnnpack
test_spec: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifact/android-llm-device-farm-test-spec.yml

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,30 @@ phases:
fi
fi;

# Run the new generic benchmark activity https://developer.android.com/tools/adb#am
- echo "Run LLM benchmark"
- |
adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n com.example.executorchllamademo/.LlmBenchmarkRunner \
--es "model_dir" "/data/local/tmp/llama" \
--es "tokenizer_path" "/data/local/tmp/llama/tokenizer.bin"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! Exactly what i want to do :D

Need to find a way to wait for the file to result file to appear. shell am start is async

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, got it. TIL. I'm still working on this to make the command works, so stay tuned :)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this kind of stuff work? 🤔
adb shell while [ ! -f /data/local/tmp/result.txt ]; do sleep 1; done

Copy link
Contributor Author

@huydhn huydhn Sep 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

adb shell doesn't like the way I write bash script, so I need to look for a work around by cat-ing the results. It works nonetheless, so I guess we are good :)


post_test:
commands:
- echo "Gather LLM benchmark results"
- |
BENCHMARK_RESULTS=""
ATTEMPT=0
MAX_ATTEMPT=10
while [ -z "${BENCHMARK_RESULTS}" ] && [ $ATTEMPT -lt $MAX_ATTEMPT ]; do
echo "Waiting for benchmark results..."
BENCHMARK_RESULTS=$(adb -s $DEVICEFARM_DEVICE_UDID shell run-as com.example.executorchllamademo cat files/benchmark_results.json)
sleep 30
((ATTEMPT++))
done

adb -s $DEVICEFARM_DEVICE_UDID shell run-as com.example.executorchllamademo ls -la files/
# Trying to pull the file using adb ends up with permission error, but this works too, so why not
echo "${BENCHMARK_RESULTS}" > $DEVICEFARM_LOG_DIR/benchmark_results.json

artifacts:
# By default, Device Farm will collect your artifacts from the $DEVICEFARM_LOG_DIR directory.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@
import android.util.Log;
import android.widget.TextView;
import androidx.annotation.NonNull;
import com.google.gson.Gson;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.Arrays;

public class LlmBenchmarkRunner extends Activity implements ModelRunnerCallback {
ModelRunner mModelRunner;
Expand All @@ -32,7 +35,12 @@ protected void onCreate(Bundle savedInstanceState) {

Intent intent = getIntent();

String modelPath = intent.getStringExtra("model_path");
File modelDir = new File(intent.getStringExtra("model_dir"));
File model =
Arrays.stream(modelDir.listFiles())
.filter(file -> file.getName().endsWith(".pte"))
.findFirst()
.get();
String tokenizerPath = intent.getStringExtra("tokenizer_path");

float temperature = intent.getFloatExtra("temperature", 0.8f);
Expand All @@ -42,7 +50,7 @@ protected void onCreate(Bundle savedInstanceState) {
}

mStatsDump = new StatsDump();
mModelRunner = new ModelRunner(modelPath, tokenizerPath, temperature, this);
mModelRunner = new ModelRunner(model.getPath(), tokenizerPath, temperature, this);
mStatsDump.loadStart = System.currentTimeMillis();
}

Expand Down Expand Up @@ -79,11 +87,21 @@ public void onGenerationStopped() {
mTextView.append(mStatsDump.toString());
});

// TODO (huydhn): Remove txt files here once the JSON format is ready
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.txt")) {
writer.write(mStatsDump.toString());
} catch (IOException e) {
e.printStackTrace();
}

// TODO (huydhn): Figure out on what the final JSON results looks like, we need something
// with the same number of fields as https://github.com/pytorch/pytorch/pull/135042
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
Gson gson = new Gson();
writer.write(gson.toJson(mStatsDump));
} catch (IOException e) {
e.printStackTrace();
}
}
}

Expand Down
Loading