14
14
import android .util .Log ;
15
15
import android .widget .TextView ;
16
16
import androidx .annotation .NonNull ;
17
+ import com .google .gson .Gson ;
18
+ import java .io .File ;
17
19
import java .io .FileWriter ;
18
20
import java .io .IOException ;
21
+ import java .util .Arrays ;
19
22
20
23
public class LlmBenchmarkRunner extends Activity implements ModelRunnerCallback {
21
24
ModelRunner mModelRunner ;
@@ -32,7 +35,12 @@ protected void onCreate(Bundle savedInstanceState) {
32
35
33
36
Intent intent = getIntent ();
34
37
35
- String modelPath = intent .getStringExtra ("model_path" );
38
+ File modelDir = new File (intent .getStringExtra ("model_dir" ));
39
+ File model =
40
+ Arrays .stream (modelDir .listFiles ())
41
+ .filter (file -> file .getName ().endsWith (".pte" ))
42
+ .findFirst ()
43
+ .get ();
36
44
String tokenizerPath = intent .getStringExtra ("tokenizer_path" );
37
45
38
46
float temperature = intent .getFloatExtra ("temperature" , 0.8f );
@@ -42,7 +50,7 @@ protected void onCreate(Bundle savedInstanceState) {
42
50
}
43
51
44
52
mStatsDump = new StatsDump ();
45
- mModelRunner = new ModelRunner (modelPath , tokenizerPath , temperature , this );
53
+ mModelRunner = new ModelRunner (model . getPath () , tokenizerPath , temperature , this );
46
54
mStatsDump .loadStart = System .currentTimeMillis ();
47
55
}
48
56
@@ -79,11 +87,21 @@ public void onGenerationStopped() {
79
87
mTextView .append (mStatsDump .toString ());
80
88
});
81
89
90
+ // TODO (huydhn): Remove txt files here once the JSON format is ready
82
91
try (FileWriter writer = new FileWriter (getFilesDir () + "/benchmark_results.txt" )) {
83
92
writer .write (mStatsDump .toString ());
84
93
} catch (IOException e ) {
85
94
e .printStackTrace ();
86
95
}
96
+
97
+ // TODO (huydhn): Figure out on what the final JSON results looks like, we need something
98
+ // with the same number of fields as https://github.com/pytorch/pytorch/pull/135042
99
+ try (FileWriter writer = new FileWriter (getFilesDir () + "/benchmark_results.json" )) {
100
+ Gson gson = new Gson ();
101
+ writer .write (gson .toJson (mStatsDump ));
102
+ } catch (IOException e ) {
103
+ e .printStackTrace ();
104
+ }
87
105
}
88
106
}
89
107
0 commit comments