
Define generic Android benchmark metric structure #5332

Closed
wants to merge 9 commits
@@ -9,7 +9,9 @@
package com.example.executorchllamademo;

import android.app.Activity;
import android.app.ActivityManager;
import android.content.Intent;
import android.os.Build;
import android.os.Bundle;
import android.util.Log;
import android.widget.TextView;
@@ -18,7 +20,11 @@
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class LlmBenchmarkRunner extends Activity implements ModelRunnerCallback {
ModelRunner mModelRunner;
@@ -50,19 +56,21 @@ protected void onCreate(Bundle savedInstanceState) {
}

mStatsDump = new StatsDump();
mStatsDump.modelName = model.getName().replace(".pte", "");
mModelRunner = new ModelRunner(model.getPath(), tokenizerPath, temperature, this);
mStatsDump.loadStart = System.currentTimeMillis();
mStatsDump.loadStart = System.nanoTime();
}

@Override
public void onModelLoaded(int status) {
mStatsDump.loadEnd = System.currentTimeMillis();
mStatsDump.loadEnd = System.nanoTime();
mStatsDump.loadStatus = status;
if (status != 0) {
Log.e("LlmBenchmarkRunner", "Loaded failed: " + status);
onGenerationStopped();
return;
}
mStatsDump.generateStart = System.currentTimeMillis();
mStatsDump.generateStart = System.nanoTime();
mModelRunner.generate(mPrompt);
}

@@ -81,36 +89,122 @@ public void onStats(String stats) {

@Override
public void onGenerationStopped() {
mStatsDump.generateEnd = System.currentTimeMillis();
mStatsDump.generateEnd = System.nanoTime();
runOnUiThread(
() -> {
mTextView.append(mStatsDump.toString());
});

// TODO (huydhn): Remove txt files here once the JSON format is ready
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.txt")) {
writer.write(mStatsDump.toString());
} catch (IOException e) {
e.printStackTrace();
}
final BenchmarkMetric.BenchmarkModel benchmarkModel =
BenchmarkMetric.extractBackendAndQuantization(mStatsDump.modelName);
final List<BenchmarkMetric> results = new ArrayList<>();
// The list of metrics currently collected includes:
// Load status
results.add(new BenchmarkMetric(benchmarkModel, "load_status", mStatsDump.loadStatus, 0));
// Model load time
results.add(
new BenchmarkMetric(
benchmarkModel,
"model_load_time(ns)",
mStatsDump.loadEnd - mStatsDump.loadStart,
0.0f));
// LLM generate time
results.add(
new BenchmarkMetric(
benchmarkModel,
"generate_time(ns)",
mStatsDump.generateEnd - mStatsDump.generateStart,
0.0f));
// Token per second
results.add(
new BenchmarkMetric(benchmarkModel, "token_per_sec", extractTPS(mStatsDump.tokens), 0.0f));

// TODO (huydhn): Figure out what the final JSON results should look like; we need something
// with the same number of fields as https://github.com/pytorch/pytorch/pull/135042
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
Gson gson = new Gson();
writer.write(gson.toJson(mStatsDump));
writer.write(gson.toJson(results));
} catch (IOException e) {
e.printStackTrace();
}
}

private double extractTPS(final String tokens) {
final Matcher m = Pattern.compile("\\d+\\.?\\d*").matcher(tokens);
if (m.find()) {
return Double.parseDouble(m.group());
} else {
return 0.0f;
}
}
}
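(For reference, extractTPS above simply pulls the first decimal number out of whatever stats string the runner reports through onStats; the snippet below illustrates that behavior on a made-up input, since the actual stats format is not part of this diff.)

// Illustration only: the stats string here is hypothetical.
Matcher m = Pattern.compile("\\d+\\.?\\d*").matcher("tokens/sec: 15.4");
double tps = m.find() ? Double.parseDouble(m.group()) : 0.0; // -> 15.4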

class BenchmarkMetric {
public static class BenchmarkModel {
// The model name, e.g. stories110M
String name;
String backend;
String quantization;

public BenchmarkModel(final String name, final String backend, final String quantization) {
this.name = name;
this.backend = backend;
this.quantization = quantization;
}
}

BenchmarkModel benchmarkModel;

// The metric name, e.g. TPS
String metric;

// The actual value and the optional target value
double actualValue;
double targetValue;

public static class DeviceInfo {
// Let's see which information we want to include here
final String device = Build.BRAND;
// The phone model and Android release version
final String arch = Build.MODEL;
final String os = "Android " + Build.VERSION.RELEASE;
final long totalMem = new ActivityManager.MemoryInfo().totalMem;
final long availMem = new ActivityManager.MemoryInfo().availMem;
}

DeviceInfo deviceInfo = new DeviceInfo();

public BenchmarkMetric(
final BenchmarkModel benchmarkModel,
final String metric,
final double actualValue,
final double targetValue) {
this.benchmarkModel = benchmarkModel;
this.metric = metric;
this.actualValue = actualValue;
this.targetValue = targetValue;
}

// TODO (huydhn): Figure out a way to extract the backend and quantization information from
// the .pte model itself instead of parsing its name
public static BenchmarkMetric.BenchmarkModel extractBackendAndQuantization(final String model) {
final Matcher m =
Pattern.compile("(?<name>\\w+)_(?<backend>\\w+)_(?<quantization>\\w+)").matcher(model);
if (m.matches()) {
return new BenchmarkMetric.BenchmarkModel(
m.group("name"), m.group("backend"), m.group("quantization"));
} else {
return new BenchmarkMetric.BenchmarkModel(model, "", "");
}
}
}
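(A quick illustration of extractBackendAndQuantization with made-up model file names, not artifacts from this PR: \w+ also matches underscores and matching is greedy, so only the last two underscore-separated segments become the backend and quantization, and a name with fewer than two underscores falls through to the else branch.)

// Hypothetical names, for illustration:
// "stories110M_xnnpack_fp32" -> name="stories110M", backend="xnnpack", quantization="fp32"
// "llama_3_qnn_q4"           -> name="llama_3", backend="qnn", quantization="q4"
// "stories110M"              -> no match, falls back to ("stories110M", "", "")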

class StatsDump {
int loadStatus;
long loadStart;
long loadEnd;
long generateStart;
long generateEnd;
String tokens;
String modelName;

@NonNull
@Override
@@ -47,34 +47,49 @@ protected void onCreate(Bundle savedInstanceState) {
// TODO: Format the string with a parsable format
Stats stats = new Stats();

// Record the time it takes to load the model and the forward method
stats.loadStart = System.nanoTime();
Module module = Module.load(model.getPath());
stats.errorCode = module.loadMethod("forward");
stats.loadEnd = System.nanoTime();

for (int i = 0; i < numIter; i++) {
long start = System.currentTimeMillis();
long start = System.nanoTime();
module.forward();
long forwardMs = System.currentTimeMillis() - start;
long forwardMs = System.nanoTime() - start;
stats.latency.add(forwardMs);
}
stats.errorCode = module.loadMethod("forward");

// TODO (huydhn): Remove txt files here once the JSON format is ready
Contributor: I should probably log the time for module.loadMethod() before first forward() 😢

Contributor Author: I see, let me try to copy it from llama and add one here too

try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.txt")) {
writer.write(stats.toString());
} catch (IOException e) {
e.printStackTrace();
}
final BenchmarkMetric.BenchmarkModel benchmarkModel =
BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", ""));
final List<BenchmarkMetric> results = new ArrayList<>();
// The list of metrics currently collected includes:
// Avg inference latency after N iterations
results.add(
new BenchmarkMetric(
benchmarkModel,
"avg_inference_latency(ns)",
stats.latency.stream().mapToDouble(l -> l).average().orElse(0.0f),
0.0f));
// Model load time
results.add(
new BenchmarkMetric(
benchmarkModel, "model_load_time(ns)", stats.loadEnd - stats.loadStart, 0.0f));
// Load status
results.add(new BenchmarkMetric(benchmarkModel, "load_status", stats.errorCode, 0));

// TODO (huydhn): Figure out what the final JSON results should look like; we need something
// with the same number of fields as https://github.com/pytorch/pytorch/pull/135042
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
Gson gson = new Gson();
writer.write(gson.toJson(stats));
writer.write(gson.toJson(results));
} catch (IOException e) {
e.printStackTrace();
}
}
}
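(Since Gson serializes instance fields by name, each entry in the results list written above should come out roughly in the shape sketched below; the values are placeholders, and the final schema is still an open question per the TODO.)

// Approximate shape of one serialized BenchmarkMetric entry (placeholder values):
// {
//   "benchmarkModel": { "name": "...", "backend": "...", "quantization": "..." },
//   "metric": "avg_inference_latency(ns)",
//   "actualValue": 0.0,
//   "targetValue": 0.0,
//   "deviceInfo": { "device": "...", "arch": "...", "os": "...", "totalMem": 0, "availMem": 0 }
// }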

class Stats {
long loadStart;
long loadEnd;
List<Long> latency = new ArrayList<>();
int errorCode = 0;

@@ -0,0 +1,74 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

package org.pytorch.minibench;

import android.app.ActivityManager;
import android.os.Build;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

class BenchmarkMetric {
public static class BenchmarkModel {
// The model name, e.g. stories110M
String name;
String backend;
String quantization;

public BenchmarkModel(final String name, final String backend, final String quantization) {
this.name = name;
this.backend = backend;
this.quantization = quantization;
}
}

BenchmarkModel benchmarkModel;

// The metric name, e.g. TPS
String metric;

// The actual value and the optional target value
double actualValue;
double targetValue;

public static class DeviceInfo {
// Let's see which information we want to include here
final String device = Build.BRAND;
// The phone model and Android release version
final String arch = Build.MODEL;
final String os = "Android " + Build.VERSION.RELEASE;
final long totalMem = new ActivityManager.MemoryInfo().totalMem;
final long availMem = new ActivityManager.MemoryInfo().availMem;
}

DeviceInfo deviceInfo = new DeviceInfo();

public BenchmarkMetric(
final BenchmarkModel benchmarkModel,
final String metric,
final double actualValue,
final double targetValue) {
this.benchmarkModel = benchmarkModel;
this.metric = metric;
this.actualValue = actualValue;
this.targetValue = targetValue;
}

// TODO (huydhn): Figure out a way to extract the backend and quantization information from
// the .pte model itself instead of parsing its name
public static BenchmarkMetric.BenchmarkModel extractBackendAndQuantization(final String model) {
final Matcher m =
Pattern.compile("(?<name>\\w+)_(?<backend>\\w+)_(?<quantization>\\w+)").matcher(model);
if (m.matches()) {
return new BenchmarkMetric.BenchmarkModel(
m.group("name"), m.group("backend"), m.group("quantization"));
} else {
return new BenchmarkMetric.BenchmarkModel(model, "", "");
}
}
}
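(One caveat with DeviceInfo as written: constructing ActivityManager.MemoryInfo() does not query the system, so totalMem and availMem keep their default value of 0 until ActivityManager.getMemoryInfo() is called. Below is a minimal sketch of how they could be populated, assuming access to a Context, which DeviceInfo currently does not take; the helper is hypothetical.)

// Hypothetical helper; requires a Context, which this class does not currently have.
static ActivityManager.MemoryInfo readMemoryInfo(android.content.Context context) {
  ActivityManager am =
      (ActivityManager) context.getSystemService(android.content.Context.ACTIVITY_SERVICE);
  ActivityManager.MemoryInfo memoryInfo = new ActivityManager.MemoryInfo();
  am.getMemoryInfo(memoryInfo); // fills in totalMem and availMem
  return memoryInfo;
}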