@@ -41,13 +41,15 @@ def __init__(
         model: nn.Module,
         tokenizer: Union[SentencePieceTokenizer, Tiktoken],
         max_seq_length: Optional[int] = None,
+        use_kv_cache: bool = False,
     ):
         device = "cuda" if torch.cuda.is_available() else "cpu"
         super().__init__(device=device)
         self._model = model
         self._tokenizer = tokenizer
         self._device = torch.device(device)
         self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
+        self._use_kv_cache = use_kv_cache

     @property
     def eot_token_id(self):
@@ -83,7 +85,15 @@ def tok_decode(self, tokens):
         return decoded

     def _model_call(self, inps):
-        return self._model(inps)
+        if self._use_kv_cache:
+            result_logits = []
+            for pos in range(self._max_seq_length):
+                pos_tensor = torch.tensor([pos], dtype=torch.int64)
+                logits = self._model(inps[:, pos : pos + 1], pos_tensor)
+                result_logits.append(logits)
+            return torch.cat(result_logits, dim=1)
+        else:
+            return self._model(inps)

     def _model_generate(self, context, max_length, eos_token_id):
         raise Exception("unimplemented")
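
Note on the KV-cache branch added above: a model exported with a KV cache consumes one token slice plus its position per forward call, so _model_call replays positions in order and torch.cat re-assembles the per-step logits into the full (batch, seq, vocab) tensor the evaluation harness expects. A minimal, self-contained sketch of that equivalence follows; TinyStepModel is a hypothetical, stateless stand-in, whereas a real KV-cache model conditions each step on its cache:

import torch
import torch.nn as nn

class TinyStepModel(nn.Module):
    # Hypothetical stand-in: accepts (token_slice, pos) like a KV-cache export,
    # but is stateless, so step-wise and full-sequence logits trivially match.
    def __init__(self, vocab_size=32, dim=16):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, tokens, pos=None):
        return self.head(self.emb(tokens))

model = TinyStepModel()
inps = torch.randint(0, 32, (1, 8))
full_logits = model(inps)  # non-KV-cache branch: single forward over the sequence
step_logits = torch.cat(
    [model(inps[:, p : p + 1], torch.tensor([p])) for p in range(inps.shape[1])],
    dim=1,
)  # KV-cache branch: one position per call, logits re-assembled along dim=1
torch.testing.assert_close(full_logits, step_logits)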
@@ -100,6 +110,7 @@ def __init__(
         model: str,
         tokenizer: Union[SentencePieceTokenizer, Tiktoken],
         max_seq_length: Optional[int] = None,
+        use_kv_cache: bool = False,
     ):
-        super().__init__(None, tokenizer, max_seq_length)
+        super().__init__(None, tokenizer, max_seq_length, use_kv_cache)
         self._model = model  # Expects model to be path to a .pte file
@@ -111,9 +122,17 @@ def __init__(
     def _model_call(self, inps):
         # Given inps (tokens), return the logits from a single forward call
         # inps: Tensor of shape (1, max_seq_len - 1)
-        # logits: Tensor of shape (1, max_seq_len - 1, 32000)
-        result = self._et_model.forward((inps,))
-        return result[0]
+        # logits: Tensor of shape (1, max_seq_len - 1, vocab_size)
+        if self._use_kv_cache:
+            result_logits = []
+            for pos in range(self._max_seq_length):
+                pos_tensor = torch.tensor([pos], dtype=torch.int64)
+                logits = self._et_model.forward((inps[:, pos : pos + 1], pos_tensor))
+                result_logits.append(logits[0])
+            return torch.cat(result_logits, dim=1)
+        else:
+            result = self._et_model.forward((inps,))
+            return result[0]


 class ETRunnerEvalWrapper(GPTFastEvalWrapper):
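
A note on the pybind branch above: a loaded .pte module is invoked with a tuple of inputs and returns a list of output tensors, which is why the non-cache path calls forward((inps,)) and unpacks result[0], and the cache path passes (token_slice, pos_tensor) and appends logits[0]. A sketch of that calling convention, assuming a hypothetical llama2_kv.pte exported with a KV cache (the file path is a placeholder, so this fragment is illustrative rather than directly runnable):

import torch
from executorch.extension.pybindings.portable_lib import _load_for_executorch

# "llama2_kv.pte" is a placeholder path; a KV-cache export's forward
# takes (token_slice, position) rather than the whole sequence.
et_model = _load_for_executorch("llama2_kv.pte")
token = torch.tensor([[42]], dtype=torch.int64)  # shape (1, 1): one token
pos = torch.tensor([0], dtype=torch.int64)       # its position in the sequence
outputs = et_model.forward((token, pos))         # pybind returns a list of tensors
logits = outputs[0]                              # step logits, (1, 1, vocab_size)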
@@ -139,7 +158,7 @@ def _model_call(self, inps):

         # Example:
         # inps: Tensor of shape (1, N)
-        # logits: Tensor of shape (1, N, 32000)
+        # logits: Tensor of shape (1, N, vocab_size)
         pass

@@ -203,6 +222,7 @@ def gen_eval_wrapper(
             tokenizer=tokenizer,
             tokenizer_bin=tokenizer_bin,
             max_seq_length=args.max_seq_length,
+            use_kv_cache=args.use_kv_cache,
         )

     # ETRunnerEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated eagerly
@@ -225,6 +245,7 @@ def gen_eval_wrapper(
         model=model,
         tokenizer=tokenizer,
         max_seq_length=args.max_seq_length,
+        use_kv_cache=args.use_kv_cache,
     )

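
Both call sites above read args.use_kv_cache, so the argument parser feeding gen_eval_wrapper must define that flag; this diff does not show where it is defined. A minimal sketch of what such a definition could look like with plain argparse (illustrative only; the real parser may live elsewhere in this codebase):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--use_kv_cache",
    action="store_true",  # defaults to False, matching the wrappers' default
    help="Evaluate a KV-cache export, calling forward one position at a time.",
)
args = parser.parse_args(["--use_kv_cache"])
assert args.use_kv_cache is True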