@@ -41,13 +41,15 @@ def __init__(
41
41
model : nn .Module ,
42
42
tokenizer : Union [SentencePieceTokenizer , Tiktoken ],
43
43
max_seq_length : Optional [int ] = None ,
44
+ use_kv_cache : bool = False ,
44
45
):
45
46
device = "cuda" if torch .cuda .is_available () else "cpu"
46
47
super ().__init__ (device = device )
47
48
self ._model = model
48
49
self ._tokenizer = tokenizer
49
50
self ._device = torch .device (device )
50
51
self ._max_seq_length = 2048 if max_seq_length is None else max_seq_length
52
+ self ._use_kv_cache = use_kv_cache
51
53
52
54
@property
53
55
def eot_token_id (self ):
@@ -83,7 +85,15 @@ def tok_decode(self, tokens):
83
85
return decoded
84
86
85
87
def _model_call (self , inps ):
86
- return self ._model (inps )
88
+ if self ._use_kv_cache :
89
+ result_logits = []
90
+ for pos in range (self ._max_seq_length ):
91
+ pos_tensor = torch .tensor ([pos ], dtype = torch .int64 )
92
+ logits = self ._model (inps [:, pos : pos + 1 ], pos_tensor )
93
+ result_logits .append (logits )
94
+ return torch .cat (result_logits , dim = 1 )
95
+ else :
96
+ return self ._model (inps )
87
97
88
98
def _model_generate (self , context , max_length , eos_token_id ):
89
99
raise Exception ("unimplemented" )
@@ -107,13 +117,22 @@ def __init__(
107
117
from executorch .extension .pybindings .portable_lib import _load_for_executorch
108
118
109
119
self ._et_model = _load_for_executorch (self ._model )
120
+ self ._use_kv_cache = self ._et_model .run_method ("use_kv_cache" )[0 ]
110
121
111
122
def _model_call (self , inps ):
112
123
# Given inps (tokens), return the logits from a single forward call
113
124
# inps: Tensor of shape (1, max_seq_len - 1)
114
- # logits: Tensor of shape (1, max_seq_len - 1, 32000)
115
- result = self ._et_model .forward ((inps ,))
116
- return result [0 ]
125
+ # logits: Tensor of shape (1, max_seq_len - 1, vocab_size)
126
+ if self ._use_kv_cache :
127
+ result_logits = []
128
+ for pos in range (self ._max_seq_length ):
129
+ pos_tensor = torch .tensor ([pos ], dtype = torch .int64 )
130
+ logits = self ._et_model .forward ((inps [:, pos : pos + 1 ], pos_tensor ))
131
+ result_logits .append (logits [0 ])
132
+ return torch .cat (result_logits , dim = 1 )
133
+ else :
134
+ result = self ._et_model .forward ((inps ,))
135
+ return result [0 ]
117
136
118
137
119
138
class ETRunnerEvalWrapper (GPTFastEvalWrapper ):
@@ -139,7 +158,7 @@ def _model_call(self, inps):
139
158
140
159
# Example:
141
160
# inps: Tensor of shape (1, N)
142
- # logits: Tensor of shape (1, N, 32000 )
161
+ # logits: Tensor of shape (1, N, vocab_size )
143
162
pass
144
163
145
164
@@ -225,6 +244,7 @@ def gen_eval_wrapper(
225
244
model = model ,
226
245
tokenizer = tokenizer ,
227
246
max_seq_length = args .max_seq_length ,
247
+ use_kv_cache = args .use_kv_cache ,
228
248
)
229
249
230
250
0 commit comments