File tree Expand file tree Collapse file tree 2 files changed +4
-10
lines changed Expand file tree Collapse file tree 2 files changed +4
-10
lines changed Original file line number Diff line number Diff line change @@ -58,12 +58,9 @@ def _model_call(self, inps):
58
58
# inps: Tensor of shape (1, max_seq_len - 1)
59
59
# logits: Tensor of shape (1, max_seq_len - 1, vocab_size)
60
60
if self ._use_kv_cache :
61
- result_logits = []
62
- for pos in range (self ._max_seq_length ):
63
- pos_tensor = torch .tensor ([pos ], dtype = torch .int64 )
64
- logits = self ._et_model .forward ((inps [:, pos : pos + 1 ], pos_tensor ))
65
- result_logits .append (logits [0 ])
66
- return torch .cat (result_logits , dim = 1 )
61
+ pos_tensor = torch .tensor ([0 ], dtype = torch .int64 , device = self .device )
62
+ result = self ._et_model .forward ((inps [:, : self ._max_seq_length ], pos_tensor ))
63
+ return result [0 ]
67
64
else :
68
65
result = self ._et_model .forward ((inps ,))
69
66
return result [0 ]
Original file line number Diff line number Diff line change @@ -77,10 +77,7 @@ def tok_decode(self, tokens):
77
77
78
78
def _model_call (self , inps ):
79
79
if self ._use_kv_cache :
80
- pos_tensor = torch .arange (
81
- self ._max_seq_length , dtype = torch .int64 , device = self .device
82
- )
83
-
80
+ pos_tensor = torch .tensor ([0 ], dtype = torch .int64 , device = self .device )
84
81
# Batch process the whole sequence.
85
82
logits = self ._model (inps [:, : self ._max_seq_length ], pos_tensor )
86
83
return logits
You can’t perform that action at this time.
0 commit comments