2 files changed: +6 -10 lines changed
Original file line number | Diff line number | Diff line change
@@ -54,12 +54,11 @@ def _model_call(self, inps):
54
54
# inps: Tensor of shape (1, max_seq_len - 1)
55
55
# logits: Tensor of shape (1, max_seq_len - 1, vocab_size)
56
56
if self ._use_kv_cache :
57
- result_logits = []
58
- for pos in range (self ._max_seq_length ):
59
- pos_tensor = torch .tensor ([pos ], dtype = torch .int64 )
60
- logits = self ._et_model .forward ((inps [:, pos : pos + 1 ], pos_tensor ))
61
- result_logits .append (logits [0 ])
62
- return torch .cat (result_logits , dim = 1 )
57
+ pos_tensor = torch .tensor ([0 ], dtype = torch .int64 , device = self .device )
58
+ result = self ._et_model .forward (
59
+ (inps [:, : self ._max_seq_length ], pos_tensor )
60
+ )
61
+ return result [0 ]
63
62
else :
64
63
result = self ._et_model .forward ((inps ,))
65
64
return result [0 ]
Original file line number | Diff line number | Diff line change
@@ -77,10 +77,7 @@ def tok_decode(self, tokens):
77
77
78
78
def _model_call (self , inps ):
79
79
if self ._use_kv_cache :
80
- pos_tensor = torch .arange (
81
- self ._max_seq_length , dtype = torch .int64 , device = self .device
82
- )
83
-
80
+ pos_tensor = torch .tensor ([0 ], dtype = torch .int64 , device = self .device )
84
81
# Batch process the whole sequence.
85
82
logits = self ._model (inps [:, : self ._max_seq_length ], pos_tensor )
86
83
return logits
You can’t perform that action at this time.
0 commit comments