@@ -31,12 +31,42 @@ class Runner {
       const std::string& tokenizer_path,
       const float temperature = 0.8f);
 
+  struct TimeStampsAndStats {
+    // Scaling factor for timestamps - in this case, we use ms.
+    const long SCALING_FACTOR_UNITS_PER_SECOND = 1000;
+    // Time stamps for the different stages of the execution
+    // model_load_start_ms: Start of model loading.
+    long model_load_start_ms;
+    // model_load_end_ms: End of model loading.
+    long model_load_end_ms;
+    // inference_start_ms: Immediately after the model is loaded (or we check
+    // for model load), measure the inference time.
+    long inference_start_ms;
+    // prompt_eval_end_ms: Prompt array allocation and tokenization. Ends right
+    // before the inference loop starts.
+    long prompt_eval_end_ms;
+    // first_token_ms: Timestamp when the first generated token is emitted.
+    long first_token_ms;
+    // inference_end_ms: End of inference/generation.
+    long inference_end_ms;
+    // Keep a running total of the time spent in sampling.
+    long aggregate_sampling_time_ms;
+    // Token count from the prompt.
+    int64_t num_prompt_tokens;
+    // Token count generated (total - prompt).
+    int64_t num_generated_tokens;
+
+    void printReport();
+    const std::string toJsonString();
+  };
+
   bool is_loaded() const;
   Error load();
   Error generate(
       const std::string& prompt,
       int32_t seq_len = 128,
-      std::function<void(const std::string&)> callback = {});
+      std::function<void(const std::string&)> on_token_generated_callback = {},
+      std::function<void(const TimeStampsAndStats&)> on_stats_callback = {});
   void stop();
 
  private:
@@ -68,36 +98,7 @@ class Runner {
   std::unique_ptr<Tokenizer> tokenizer_;
   std::unique_ptr<Sampler> sampler_;
   bool shouldStop_{false};
-
-  struct TimeStamps {
-    // Scaling factor for timestamps - in this case, we use ms.
-    const long SCALING_FACTOR_UNITS_PER_SECOND = 1000;
-    // Time stamps for the different stages of the execution
-    // model_load_start_ms: Start of model loading.
-    long model_load_start_ms;
-    // model_load_end_ms: End of model loading.
-    long model_load_end_ms;
-    // inference_start_ms: Immediately after the model is loaded (or we check
-    // for model load), measure the inference time.
-    long inference_start_ms;
-    // prompt_eval_end_ms: Prompt array allocation and tokenization. Ends right
-    // before the inference loop starts.
-    long prompt_eval_end_ms;
-    // first_token_ms: Timestamp when the first generated token is emitted.
-    long first_token_ms;
-    // inference_end_ms: End of inference/generation.
-    long inference_end_ms;
-    // Keep a running total of the time spent in sampling.
-    long aggregate_sampling_time_ms;
-
-    void printReport(
-        const int64_t& num_prompt_tokens,
-        const int64_t& num_generated_tokens);
-    const std::string toJsonString(
-        const int64_t& num_prompt_tokens,
-        const int64_t& num_generated_tokens);
-  };
-  TimeStamps timers_;
+  TimeStampsAndStats timers_;
 };
 
 } // namespace torch::executor
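
A minimal caller sketch of the changed API follows. It only exercises what this diff declares (the two generate() callbacks, the public TimeStampsAndStats struct, and its printReport()/toJsonString() helpers); the header path, the Runner constructor's first argument, and the model/tokenizer paths are assumptions for illustration, not part of the diff.

// Hypothetical usage sketch, not part of this PR.
#include <iostream>
#include <string>

#include "runner.h" // assumed include path for the Runner declaration

int main() {
  // Constructor arguments are assumed: model path, tokenizer path, temperature.
  torch::executor::Runner runner(
      "/path/to/model.pte",
      "/path/to/tokenizer.bin",
      /*temperature=*/0.8f);

  const auto status = runner.generate(
      "Tell me a story.",
      /*seq_len=*/128,
      // Invoked for each generated token as it is produced.
      [](const std::string& token) { std::cout << token << std::flush; },
      // Invoked with the collected timestamps and token counts.
      [](const torch::executor::Runner::TimeStampsAndStats& stats) {
        // printReport()/toJsonString() are declared non-const above,
        // so work on a copy of the stats struct.
        torch::executor::Runner::TimeStampsAndStats report = stats;
        report.printReport();
        std::cout << report.toJsonString() << std::endl;
      });
  (void)status;
  return 0;
}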