Skip to content

Commit 4bb555c

Browse files
committed
Fix instrumentation
1 parent 5e91666 commit 4bb555c

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

extension/android/executorch_android/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ dependencies {
4747
androidTestImplementation 'androidx.test.ext:junit:1.1.5'
4848
androidTestImplementation 'androidx.test:rules:1.2.0'
4949
androidTestImplementation 'commons-io:commons-io:2.4'
50+
androidTestImplementation 'org.json:json:20250107'
5051
}
5152

5253
import com.vanniktech.maven.publish.SonatypeHost

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.java

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
import org.apache.commons.io.FileUtils;
3535
import androidx.test.ext.junit.runners.AndroidJUnit4;
3636
import androidx.test.InstrumentationRegistry;
37+
import org.json.JSONException;
38+
import org.json.JSONObject;
3739
import org.pytorch.executorch.extension.llm.LlmCallback;
3840
import org.pytorch.executorch.extension.llm.LlmModule;
3941

@@ -94,9 +96,17 @@ public void onResult(String result) {
9496
}
9597

9698
@Override
97-
public void onStats(String result) {
98-
// TODO: Calculate tps
99-
// LlmModuleInstrumentationTest.this.onStats(tps);
99+
public void onStats(String stats) {
100+
float tps = 0;
101+
try {
102+
JSONObject jsonObject = new JSONObject(stats);
103+
int numGeneratedTokens = jsonObject.getInt("generated_tokens");
104+
int inferenceEndMs = jsonObject.getInt("inference_end_ms");
105+
int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms");
106+
tps = (float) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs) * 1000;
107+
LlmModuleInstrumentationTest.this.onStats(tps);
108+
} catch (JSONException e) {
109+
}
100110
}
101111
});
102112

0 commit comments

Comments
 (0)