
Commit d9aed5a

cccclai authored and facebook-github-bot committed
use index_put only in kv cache update to reduce number of operators
Summary:
The decomposition of the slice-assignment version

```
class IndexPut(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, input_pos, value):
        x[:, :, input_pos] = value
        return x
```

is

```
opcode         name             target                      args                                             kwargs
-------------  ---------------  --------------------------  -----------------------------------------------  --------
placeholder    x                x                           ()                                               {}
placeholder    input_pos        input_pos                   ()                                               {}
placeholder    value            value                       ()                                               {}
call_function  slice_1          aten.slice.Tensor           (x, 0, 0, 9223372036854775807)                   {}
call_function  slice_2          aten.slice.Tensor           (slice_1, 1, 0, 9223372036854775807)             {}
call_function  index_put        aten.index_put.default      (slice_2, [None, None, input_pos], value)        {}
call_function  slice_3          aten.slice.Tensor           (x, 0, 0, 9223372036854775807)                   {}
call_function  slice_scatter    aten.slice_scatter.default  (slice_3, index_put, 1, 0, 9223372036854775807)  {}
call_function  slice_scatter_1  aten.slice_scatter.default  (x, slice_scatter, 0, 0, 9223372036854775807)    {}
output         output           output                      ((slice_scatter_1, slice_scatter_1),)            {}
```

However, `x[:, :, input_pos] = value` only updates the contents of `x` with `value`; it is essentially a single `index_put`. Replacing `x[:, :, input_pos] = value` with `torch.ops.aten.index_put_(x, [None, None, input_pos], value)` reduces the operator count from 6 to 1:

```
class IndexPut(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, input_pos, value):
        torch.ops.aten.index_put_(x, [None, None, input_pos], value)
        return x
```

whose decomposition is

```
opcode         name       target                  args                                 kwargs
-------------  ---------  ----------------------  -----------------------------------  --------
placeholder    x          x                       ()                                   {}
placeholder    input_pos  input_pos               ()                                   {}
placeholder    value      value                   ()                                   {}
call_function  index_put  aten.index_put.default  (x, [None, None, input_pos], value)  {}
output         output     output                  ((index_put, index_put),)            {}
```

A more principled long-term fix is graph pattern matching: detect the decomposed slice/index_put/slice_scatter chain and rewrite it to the single index_put.

Differential Revision: D57949659
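For reference, the graphs above can be reproduced with `torch.export` and `print_tabular()`. This is a sketch, not part of the change: the example shapes are made up, and it assumes a PyTorch version whose `torch.export` tolerates in-place mutation of a user input (the mutated tensor then appears in the output node, as in the tables above):

```
import torch
from torch.export import export


class IndexPut(torch.nn.Module):
    def forward(self, x, input_pos, value):
        # Single in-place aten op instead of slice assignment
        torch.ops.aten.index_put_(x, [None, None, input_pos], value)
        return x


# Hypothetical shapes: x is [B, H, S_max, D]; value covers len(input_pos) slots on dim 2
x = torch.zeros(1, 8, 16, 4)
input_pos = torch.tensor([3])
value = torch.ones(1, 8, 1, 4)

ep = export(IndexPut(), (x, input_pos, value))
ep.graph_module.graph.print_tabular()
```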
1 parent 9e86860 commit d9aed5a

File tree

1 file changed: +2, -5 lines

examples/models/llama2/llama_transformer.py

Lines changed: 2 additions & 5 deletions
```
@@ -208,11 +208,8 @@ def update(
         self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache
-        k_out = self.k_cache
-        v_out = self.v_cache
-        k_out[:, :, input_pos] = k_val
-        v_out[:, :, input_pos] = v_val
-
+        k_out = torch.ops.aten.index_put_(self.k_cache, [None, None, input_pos], k_val)
+        v_out = torch.ops.aten.index_put_(self.v_cache, [None, None, input_pos], v_val)
         return k_out, v_out
```
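In context, `update()` belongs to the KV-cache module in `llama_transformer.py`. Below is a minimal self-contained sketch of such a cache built around the new in-place update; everything outside `update()` (constructor arguments, buffer shapes, dtype) is an assumption for illustration, not taken from the diff:

```
from typing import Tuple

import torch


class KVCache(torch.nn.Module):
    def __init__(self, n_heads: int, max_seq_len: int, head_dim: int, batch_size: int = 1):
        super().__init__()
        cache_shape = (batch_size, n_heads, max_seq_len, head_dim)
        # Preallocated static caches; index_put_ writes into them in place
        self.register_buffer("k_cache", torch.zeros(cache_shape))
        self.register_buffer("v_cache", torch.zeros(cache_shape))

    def update(
        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # input_pos: [S]; k_val, v_val: [B, H, S, D] -- written along the sequence dim
        k_out = torch.ops.aten.index_put_(self.k_cache, [None, None, input_pos], k_val)
        v_out = torch.ops.aten.index_put_(self.v_cache, [None, None, input_pos], v_val)
        return k_out, v_out
```

Because `index_put_` mutates the buffers and returns them, `k_out` and `v_out` alias the caches themselves, which is what keeps the exported graph at one `index_put` per cache instead of the six-operator slice/scatter chain.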