Commit 1c66ae4 (1 parent: 9ce50b4)

Update

[ghstack-poisoned]

3 files changed: +7 -7

examples/models/llama/llama_transformer.py (1 addition, 1 deletion)

@@ -154,7 +154,7 @@ def __init__(
     ):
         super().__init__()
         self.max_seq_length = max_seq_length
-        self.is_tranposed = transpose_cache
+        self.is_transposed = transpose_cache
         if transpose_cache:
             cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
         else:

examples/models/llama/source_transformation/quantized_kv_cache.py (2 additions, 2 deletions)

@@ -193,7 +193,7 @@ def update(self, input_pos, k_val, v_val):
     @classmethod
     def from_float(cls, kv_cache, cache_type: QuantizedCacheType):
         cache_shape = kv_cache.k_cache.shape
-        if kv_cache.is_tranposed:
+        if kv_cache.is_transposed:
             max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
         else:
             max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
@@ -203,7 +203,7 @@ def from_float(cls, kv_cache, cache_type: QuantizedCacheType):
             n_heads,
             head_dim,
             cache_type,
-            kv_cache.is_tranposed,
+            kv_cache.is_transposed,
             kv_cache.enable_dynamic_shape,
         )
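
The layout flag matters in from_float because the same shape tuple unpacks to different dimension names depending on how the source cache was stored. A small sketch of just that unpacking, using a stand-in object rather than the real KVCache class (the SimpleNamespace and the sizes are illustrative assumptions):

from types import SimpleNamespace

import torch

# Stand-in for a float KV cache in the transposed layout (illustrative only).
kv_cache = SimpleNamespace(
    k_cache=torch.zeros(1, 8, 128, 64),
    is_transposed=True,
)

cache_shape = kv_cache.k_cache.shape
if kv_cache.is_transposed:
    max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
else:
    max_batch_size, max_seq_length, n_heads, head_dim = cache_shape

print(n_heads, max_seq_length)  # 8 128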

examples/models/llama/source_transformation/test_quantized_kv_cache.py (4 additions, 4 deletions)

@@ -48,8 +48,8 @@ def setUp(self):
         self.transpose_kv_cache = False
         self.dtype = torch.float32

-    def _test_simple_update_fetch(self, is_tranposed=False, is_dynamic_shape=False):
-        self.transpose_kv_cache = is_tranposed
+    def _test_simple_update_fetch(self, is_transposed=False, is_dynamic_shape=False):
+        self.transpose_kv_cache = is_transposed
         self.enable_dynamic_shape = is_dynamic_shape
         input_pos = torch.tensor([0, 1, 2])
         self.seq_len = input_pos.size(0)
@@ -122,7 +122,7 @@ def test_simple_update_fetch_not_transposed_dynamic_shape(self):
         self._test_simple_update_fetch(is_dynamic_shape=True)

     def test_simple_update_fetch_transposed(self):
-        self._test_simple_update_fetch(is_tranposed=True)
+        self._test_simple_update_fetch(is_transposed=True)

     def test_simple_update_fetch_transposed_dynamic_shape(self):
-        self._test_simple_update_fetch(is_tranposed=True, is_dynamic_shape=True)
+        self._test_simple_update_fetch(is_transposed=True, is_dynamic_shape=True)
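
Since this commit is a pure rename, a quick way to confirm nothing was missed is to assert that the old misspelling no longer appears in the files it touches. A throwaway check along those lines (run from a repo checkout; the paths are the ones listed in this commit):

import pathlib

paths = [
    "examples/models/llama/llama_transformer.py",
    "examples/models/llama/source_transformation/quantized_kv_cache.py",
    "examples/models/llama/source_transformation/test_quantized_kv_cache.py",
]
for p in paths:
    text = pathlib.Path(p).read_text()
    assert "is_tranposed" not in text, f"stale misspelling in {p}"
print("rename is consistent across the touched files")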
