Commit 6574788

cccclai authored and facebook-github-bot committed
use index_put only in kv cache update to reduce number of operators (#3786)
Summary: Pull Request resolved: #3786

The decomposition from

```
class IndexPut(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, input_pos, value):
        x[:, :, input_pos] = value
        return x
```

is

```
opcode         name             target                      args                                             kwargs
-------------  ---------------  --------------------------  -----------------------------------------------  --------
placeholder    x                x                           ()                                               {}
placeholder    input_pos        input_pos                   ()                                               {}
placeholder    value            value                       ()                                               {}
call_function  slice_1          aten.slice.Tensor           (x, 0, 0, 9223372036854775807)                   {}
call_function  slice_2          aten.slice.Tensor           (slice_1, 1, 0, 9223372036854775807)             {}
call_function  index_put        aten.index_put.default      (slice_2, [None, None, input_pos], value)        {}
call_function  slice_3          aten.slice.Tensor           (x, 0, 0, 9223372036854775807)                   {}
call_function  slice_scatter    aten.slice_scatter.default  (slice_3, index_put, 1, 0, 9223372036854775807)  {}
call_function  slice_scatter_1  aten.slice_scatter.default  (x, slice_scatter, 0, 0, 9223372036854775807)    {}
output         output           output                      ((slice_scatter_1, slice_scatter_1),)            {}
```

However, `x[:, :, input_pos] = value` really just updates the contents of `x` with `value`; it is essentially a single `index_put`. By replacing `x[:, :, input_pos] = value` with `torch.ops.aten.index_put_(x, [None, None, input_pos], value)`, we reduce the number of operators from 6 to 1.

```
class IndexPut(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, input_pos, value):
        torch.ops.aten.index_put_(x, [None, None, input_pos], value)
        return x
```

Its decomposition is

```
opcode         name       target                  args                                 kwargs
-------------  ---------  ----------------------  -----------------------------------  --------
placeholder    x          x                       ()                                   {}
placeholder    input_pos  input_pos               ()                                   {}
placeholder    value      value                   ()                                   {}
call_function  index_put  aten.index_put.default  (x, [None, None, input_pos], value)  {}
output         output     output                  ((index_put, index_put),)            {}
```

A more principled long-term fix is pattern matching: detect the decomposed slice/index_put/slice_scatter sequence and replace it with the simplified single-operator pattern.

Perf for stories, before the diff:

```
I 00:00:03.437290 executorch:runner.cpp:419] Prompt Tokens: 9 Generated Tokens: 118
I 00:00:03.437295 executorch:runner.cpp:425] Model Load Time: 0.763000 (seconds)
I 00:00:03.437301 executorch:runner.cpp:435] Total inference time: 2.661000 (seconds) Rate: 44.344231 (tokens/second)
I 00:00:03.437305 executorch:runner.cpp:443] Prompt evaluation: 0.185000 (seconds) Rate: 48.648649 (tokens/second)
I 00:00:03.437309 executorch:runner.cpp:454] Generated 118 tokens: 2.476000 (seconds) Rate: 47.657512 (tokens/second)
I 00:00:03.437313 executorch:runner.cpp:462] Time to first generated token: 0.206000 (seconds)
I 00:00:03.437315 executorch:runner.cpp:469] Sampling time over 127 tokens: 0.042000 (seconds)
```

After the diff:

```
I 00:00:03.195257 executorch:runner.cpp:419] Prompt Tokens: 9 Generated Tokens: 118
I 00:00:03.195295 executorch:runner.cpp:425] Model Load Time: 0.683000 (seconds)
I 00:00:03.195314 executorch:runner.cpp:435] Total inference time: 2.502000 (seconds) Rate: 47.162270 (tokens/second)
I 00:00:03.195319 executorch:runner.cpp:443] Prompt evaluation: 0.175000 (seconds) Rate: 51.428571 (tokens/second)
I 00:00:03.195323 executorch:runner.cpp:454] Generated 118 tokens: 2.327000 (seconds) Rate: 50.709067 (tokens/second)
I 00:00:03.195327 executorch:runner.cpp:462] Time to first generated token: 0.195000 (seconds)
I 00:00:03.195330 executorch:runner.cpp:469] Sampling time over 127 tokens: 0.049000 (seconds)
```

Differential Revision: D57949659
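As a sanity check, the two decompositions above can be reproduced with `torch.export`. The sketch below is illustrative and not part of the commit: the module names, shapes, and the buffer-based setup (mutating a registered buffer rather than a user input, to mirror the KV cache) are assumptions.

```
import torch

class SliceAssign(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Toy cache buffer; the in-place update becomes a buffer mutation.
        self.register_buffer("cache", torch.zeros(1, 4, 16, 8))

    def forward(self, input_pos, value):
        self.cache[:, :, input_pos] = value
        return self.cache

class DirectIndexPut(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("cache", torch.zeros(1, 4, 16, 8))

    def forward(self, input_pos, value):
        torch.ops.aten.index_put_(self.cache, [None, None, input_pos], value)
        return self.cache

input_pos = torch.tensor([3])
value = torch.ones(1, 4, 1, 8)
for mod in (SliceAssign(), DirectIndexPut()):
    ep = torch.export.export(mod, (input_pos, value))
    # The first graph lowers to a slice/index_put/slice_scatter chain,
    # the second to a single aten.index_put call.
    ep.graph_module.graph.print_tabular()
```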
1 parent 0333390 · commit 6574788

File tree

1 file changed (+9, -4)

examples/models/llama2/llama_transformer.py

Lines changed: 9 additions & 4 deletions
```
@@ -191,10 +191,15 @@ def update(
             narrowed_v.copy_(v_val)
             return self.k_cache, self.v_cache
         else:
-            k_out = self.k_cache
-            v_out = self.v_cache
-            k_out[:, :, input_pos] = k_val
-            v_out[:, :, input_pos] = v_val
+            k_out = torch.ops.aten.index_put_(
+                self.k_cache, [None, None, input_pos], k_val
+            )
+            v_out = torch.ops.aten.index_put_(
+                self.v_cache, [None, None, input_pos], v_val
+            )
+            v_out = torch.ops.aten.index_put_(
+                self.v_cache, [None, None, input_pos], v_val
+            )

             return k_out, v_out
```
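For context, here is a hedged sketch of how the updated cache path behaves, using standalone toy caches rather than the actual `KVCache` module from llama_transformer.py; the shapes and the free-standing `update` helper are assumptions for illustration.

```
import torch

# Toy caches with shape (batch, n_heads, max_seq_len, head_dim);
# the real module stores these as registered buffers.
k_cache = torch.zeros(1, 8, 128, 64)
v_cache = torch.zeros(1, 8, 128, 64)

def update(input_pos, k_val, v_val):
    # Single in-place scatter along the sequence dimension, mirroring the
    # diff: one aten.index_put instead of a slice/index_put/slice_scatter chain.
    k_out = torch.ops.aten.index_put_(k_cache, [None, None, input_pos], k_val)
    v_out = torch.ops.aten.index_put_(v_cache, [None, None, input_pos], v_val)
    return k_out, v_out

input_pos = torch.tensor([5])      # position of the new token
k_val = torch.randn(1, 8, 1, 64)
v_val = torch.randn(1, 8, 1, 64)
k_out, v_out = update(input_pos, k_val, v_val)
assert torch.equal(k_out[:, :, 5], k_val[:, :, 0])
```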
