@@ -41,13 +41,15 @@ def __init__(
         model: nn.Module,
         tokenizer: Union[SentencePieceTokenizer, Tiktoken],
         max_seq_length: Optional[int] = None,
+        use_kv_cache: bool = False,
     ):
         device = "cuda" if torch.cuda.is_available() else "cpu"
         super().__init__(device=device)
         self._model = model
         self._tokenizer = tokenizer
         self._device = torch.device(device)
         self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
+        self._use_kv_cache = use_kv_cache

     @property
     def eot_token_id(self):
@@ -83,7 +85,15 @@ def tok_decode(self, tokens):
         return decoded

     def _model_call(self, inps):
-        return self._model(inps)
+        if self._use_kv_cache:
+            result_logits = []
+            for pos in range(self._max_seq_length):
+                pos_tensor = torch.tensor([pos], dtype=torch.int64)
+                logits = self._model(inps[:, pos : pos + 1], pos_tensor)
+                result_logits.append(logits)
+            return torch.cat(result_logits, dim=1)
+        else:
+            return self._model(inps)

     def _model_generate(self, context, max_length, eos_token_id):
         raise Exception("unimplemented")
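
The KV-cache branch above evaluates the sequence one token at a time, passing each token's position so a cache-enabled model can reuse previously computed keys and values, then concatenates the per-position logits back into one (1, seq_len, vocab_size) tensor. A minimal standalone sketch of that loop follows; the StubModel class and its dimensions are invented here purely for illustration:

import torch

# Hypothetical stand-in for a KV-cache-enabled model: accepts a (1, 1)
# token slice plus its position, returns logits of shape (1, 1, vocab_size).
class StubModel(torch.nn.Module):
    def __init__(self, vocab_size: int = 32):
        super().__init__()
        self.vocab_size = vocab_size

    def forward(self, tokens, input_pos):
        return torch.zeros(tokens.shape[0], tokens.shape[1], self.vocab_size)

model = StubModel()
inps = torch.randint(0, 32, (1, 8))  # (batch=1, seq_len=8)

# Same pattern as _model_call above: one forward call per position,
# then concatenate the per-token logits along the sequence dimension.
result_logits = []
for pos in range(inps.shape[1]):
    pos_tensor = torch.tensor([pos], dtype=torch.int64)
    result_logits.append(model(inps[:, pos : pos + 1], pos_tensor))
logits = torch.cat(result_logits, dim=1)  # shape (1, 8, 32)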
@@ -100,9 +110,11 @@ def __init__(
         model: str,
         tokenizer: Union[SentencePieceTokenizer, Tiktoken],
         max_seq_length: Optional[int] = None,
+        use_kv_cache: bool = False,
     ):
         super().__init__(None, tokenizer, max_seq_length)
         self._model = model  # Expects model to be path to a .pte file
+        self._use_kv_cache = use_kv_cache

         from executorch.extension.pybindings.portable_lib import _load_for_executorch

@@ -111,9 +123,17 @@ def __init__(
     def _model_call(self, inps):
         # Given inps (tokens), return the logits from a single forward call
         # inps: Tensor of shape (1, max_seq_len - 1)
-        # logits: Tensor of shape (1, max_seq_len - 1, 32000)
-        result = self._et_model.forward((inps,))
-        return result[0]
+        # logits: Tensor of shape (1, max_seq_len - 1, vocab_size)
+        if self._use_kv_cache:
+            result_logits = []
+            for pos in range(self._max_seq_length):
+                pos_tensor = torch.tensor([pos], dtype=torch.int64)
+                logits = self._et_model.forward((inps[:, pos : pos + 1], pos_tensor))
+                result_logits.append(logits[0])
+            return torch.cat(result_logits, dim=1)
+        else:
+            result = self._et_model.forward((inps,))
+            return result[0]


 class ETRunnerEvalWrapper(GPTFastEvalWrapper):
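
For context, the .pte path uses the ExecuTorch pybindings calling convention seen in the hunk above: forward() takes a tuple of inputs and returns a list of outputs, with the logits at index 0. A minimal sketch of one decode step, assuming a KV-cache-enabled model has already been exported; the file path and vocab size are placeholders:

import torch
from executorch.extension.pybindings.portable_lib import _load_for_executorch

et_model = _load_for_executorch("llama.pte")  # placeholder path

# One decode step: a single token slice plus its position, mirroring
# the KV-cache branch of _model_call above.
tokens = torch.randint(0, 32000, (1, 1))  # vocab size is illustrative
pos = torch.tensor([0], dtype=torch.int64)
outputs = et_model.forward((tokens, pos))
logits = outputs[0]  # Tensor of shape (1, 1, vocab_size)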
@@ -139,7 +159,7 @@ def _model_call(self, inps):

         # Example:
         # inps: Tensor of shape (1, N)
-        # logits: Tensor of shape (1, N, 32000)
+        # logits: Tensor of shape (1, N, vocab_size)
         pass

@@ -212,6 +232,7 @@ def gen_eval_wrapper(
         # Exported model takes at most (max_seq_length - 1) tokens.
         # Note that the eager model takes at most max_seq_length tokens.
         max_seq_length=args.max_seq_length - 1,
+        use_kv_cache=args.use_kv_cache,
     )

     # GPTFastEvalWrapper: Create a wrapper around a pre-exported model
@@ -225,6 +246,7 @@ def gen_eval_wrapper(
         model=model,
         tokenizer=tokenizer,
         max_seq_length=args.max_seq_length,
+        use_kv_cache=args.use_kv_cache,
     )

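
Both call sites assume the eval script's argument parser exposes a matching flag. A minimal sketch of what such a flag could look like; the exact flag name and help text in the real parser are assumptions:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--use_kv_cache",  # assumed flag name, matching args.use_kv_cache above
    default=False,
    action="store_true",
    help="Evaluate one token at a time using the model's KV cache.",
)
args = parser.parse_args(["--use_kv_cache"])
assert args.use_kv_cache is True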