Skip to content

Commit fde0f80

Browse files
committed
address review item and linting
1 parent f0873c3 commit fde0f80

File tree

2 files changed

+16
-11
lines changed

2 files changed

+16
-11
lines changed

backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp

Lines changed: 12 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -132,18 +132,21 @@ void add_binary_op_buffer_node(
132132
// Shader params buffers
133133
{},
134134
// Specialization Constants
135-
{graph.packed_dim_of(out), graph.packed_dim_of(in1), graph.packed_dim_of(in2)},
135+
{graph.packed_dim_of(out),
136+
graph.packed_dim_of(in1),
137+
graph.packed_dim_of(in2)},
136138
// Resizing Logic
137139
resize_binary_op_node,
138140
{},
139-
{{graph.sizes_pc_of(in1),
140-
graph.sizes_pc_of(in2),
141-
graph.strides_pc_of(out),
142-
graph.strides_pc_of(in1),
143-
graph.strides_pc_of(in2),
144-
graph.numel_pc_of(out),
145-
PushConstantDataInfo(&alpha_val, sizeof(float)),
146-
}}));
141+
{{
142+
graph.sizes_pc_of(in1),
143+
graph.sizes_pc_of(in2),
144+
graph.strides_pc_of(out),
145+
graph.strides_pc_of(in1),
146+
graph.strides_pc_of(in2),
147+
graph.numel_pc_of(out),
148+
PushConstantDataInfo(&alpha_val, sizeof(float)),
149+
}}));
147150
}
148151

149152
void add_binary_op_node(

examples/qualcomm/oss_scripts/llama/model/static_llama.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,9 +19,11 @@
1919
def apply_rotary_emb_single(
2020
x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
2121
) -> torch.Tensor:
22-
# Change to RoPE of huggingface version
22+
# The HuggingFace implementation of RoPE processes query and key as two halves instead of in an interleaved way.
23+
# The main difference is the stride in the StrideSlice op. For the interleaved way, the stride is two, which is not friendly to the HTP backend.
24+
# Ref: https://github.com/huggingface/transformers/issues/25199
2325
x_r, x_i = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
24-
# brodcast for batch_prefill mode input x
26+
# broadcast for batch_prefill mode input x
2527
if x.dim() == 4:
2628
freqs_cos = freqs_cos[None, None, :, :]
2729
freqs_sin = freqs_sin[None, None, :, :]

0 commit comments

Comments
 (0)