Skip to content

Commit e342a92

Browse files
Kush Rastogi authored and facebook-github-bot committed
Width Packing Mat1 input for Quantized Linear (#6149)
Summary: Pull Request resolved: #6149 Width packing mat1 input for Quantized Linear as ASR model provides channel-packed matrix while operator does not support channel-packed yet. Reviewed By: nathanaelsee, jorgep31415 Differential Revision: D64065606 fbshipit-source-id: 2a7d43d432deef7245d1d45f5c760b0f42627551
1 parent 517fddb commit e342a92

File tree

1 file changed

+30
-13
lines changed

1 file changed

+30
-13
lines changed

backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp

Lines changed: 30 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -71,46 +71,63 @@ void add_q_8w_linear_node(
7171
const ValueRef q_mat2_data,
7272
const ValueRef scales_data,
7373
const ValueRef out) {
74+
auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
75+
ValueRef mat1_W_packed = mat1;
76+
ValueRef out_W_packed = out;
77+
if (!graph.is_buffer_storage(out) &&
78+
graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
79+
// Ensure mat1 is width packed
80+
mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked);
81+
viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});
82+
// Ensure out is packed correctly
83+
out_W_packed = graph.add_tensor_like(out, utils::kWidthPacked);
84+
}
7485
ValueRef q_mat2 =
7586
prepack_if_tensor_ref(graph, q_mat2_data, utils::kWidthPacked);
7687
ValueRef scales =
7788
prepack_if_tensor_ref(graph, scales_data, utils::kWidthPacked);
7889

7990
std::string kernel_name = "q_8w_linear";
8091
kernel_name.reserve(kShaderNameReserve);
81-
add_packed_dim_suffix(kernel_name, graph.packed_dim_of(mat1));
92+
add_packed_dim_suffix(kernel_name, graph.packed_dim_of(mat1_W_packed));
8293
add_packed_dim_suffix(kernel_name, graph.packed_dim_of(q_mat2));
83-
add_dtype_suffix(kernel_name, graph.dtype_of(out));
84-
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
94+
add_dtype_suffix(kernel_name, graph.dtype_of(out_W_packed));
95+
add_storage_type_suffix(kernel_name, graph.storage_type_of(out_W_packed));
8596

8697
vkapi::ParamsBindList ubos({});
87-
if (graph.is_buffer_storage(out)) {
98+
if (graph.is_buffer_storage(out_W_packed)) {
8899
ubos.append(
89-
{graph.sizes_ubo(out),
90-
graph.strides_ubo(out),
91-
graph.numel_ubo(out),
92-
graph.sizes_ubo(mat1),
100+
{graph.sizes_ubo(out_W_packed),
101+
graph.strides_ubo(out_W_packed),
102+
graph.numel_ubo(out_W_packed),
103+
graph.sizes_ubo(mat1_W_packed),
93104
graph.strides_ubo(mat1),
94105
graph.strides_ubo(q_mat2),
95106
graph.strides_ubo(scales)});
96107
} else {
97-
ubos.append({graph.logical_limits_ubo(out), graph.sizes_ubo(mat1)});
108+
ubos.append(
109+
{graph.logical_limits_ubo(out_W_packed),
110+
graph.sizes_ubo(mat1_W_packed)});
98111
}
99112

100113
graph.execute_nodes().emplace_back(new DispatchNode(
101114
graph,
102115
VK_KERNEL_FROM_STR(kernel_name),
103-
graph.create_global_wg_size(out),
104-
graph.create_local_wg_size(out),
116+
graph.create_global_wg_size(out_W_packed),
117+
graph.create_local_wg_size(out_W_packed),
105118
// Inputs and Outputs
106-
{{out, vkapi::MemoryAccessType::WRITE},
107-
{{mat1, q_mat2, scales}, vkapi::MemoryAccessType::READ}},
119+
{{out_W_packed, vkapi::MemoryAccessType::WRITE},
120+
{{mat1_W_packed, q_mat2, scales}, vkapi::MemoryAccessType::READ}},
108121
// Shader params buffers
109122
ubos,
110123
// Specialization Constants
111124
{},
112125
// Resizing Logic
113126
resize_qlinear_node));
127+
if (!graph.is_buffer_storage(out) &&
128+
graph.packed_dim_of(out) != WHCN::kWidthDim) {
129+
viewFn(graph, {out_W_packed, graph.add_none(), out});
130+
}
114131
}
115132

116133
void weight_int8pack_mm(

0 commit comments

Comments (0)