@@ -71,46 +71,63 @@ void add_q_8w_linear_node(
     const ValueRef q_mat2_data,
     const ValueRef scales_data,
     const ValueRef out) {
+  auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
+  ValueRef mat1_W_packed = mat1;
+  ValueRef out_W_packed = out;
+  if (!graph.is_buffer_storage(out) &&
+      graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
+    // Ensure mat1 is width packed
+    mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked);
+    viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});
+    // Ensure out is packed correctly
+    out_W_packed = graph.add_tensor_like(out, utils::kWidthPacked);
+  }
   ValueRef q_mat2 =
       prepack_if_tensor_ref(graph, q_mat2_data, utils::kWidthPacked);
   ValueRef scales =
       prepack_if_tensor_ref(graph, scales_data, utils::kWidthPacked);
 
   std::string kernel_name = "q_8w_linear";
   kernel_name.reserve(kShaderNameReserve);
-  add_packed_dim_suffix(kernel_name, graph.packed_dim_of(mat1));
+  add_packed_dim_suffix(kernel_name, graph.packed_dim_of(mat1_W_packed));
   add_packed_dim_suffix(kernel_name, graph.packed_dim_of(q_mat2));
-  add_dtype_suffix(kernel_name, graph.dtype_of(out));
-  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
+  add_dtype_suffix(kernel_name, graph.dtype_of(out_W_packed));
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(out_W_packed));
 
   vkapi::ParamsBindList ubos({});
-  if (graph.is_buffer_storage(out)) {
+  if (graph.is_buffer_storage(out_W_packed)) {
     ubos.append(
-        {graph.sizes_ubo(out),
-         graph.strides_ubo(out),
-         graph.numel_ubo(out),
-         graph.sizes_ubo(mat1),
+        {graph.sizes_ubo(out_W_packed),
+         graph.strides_ubo(out_W_packed),
+         graph.numel_ubo(out_W_packed),
+         graph.sizes_ubo(mat1_W_packed),
          graph.strides_ubo(mat1),
          graph.strides_ubo(q_mat2),
          graph.strides_ubo(scales)});
   } else {
-    ubos.append({graph.logical_limits_ubo(out), graph.sizes_ubo(mat1)});
+    ubos.append(
+        {graph.logical_limits_ubo(out_W_packed),
+         graph.sizes_ubo(mat1_W_packed)});
   }
 
   graph.execute_nodes().emplace_back(new DispatchNode(
       graph,
       VK_KERNEL_FROM_STR(kernel_name),
-      graph.create_global_wg_size(out),
-      graph.create_local_wg_size(out),
+      graph.create_global_wg_size(out_W_packed),
+      graph.create_local_wg_size(out_W_packed),
       // Inputs and Outputs
-      {{out, vkapi::MemoryAccessType::WRITE},
-       {{mat1, q_mat2, scales}, vkapi::MemoryAccessType::READ}},
+      {{out_W_packed, vkapi::MemoryAccessType::WRITE},
+       {{mat1_W_packed, q_mat2, scales}, vkapi::MemoryAccessType::READ}},
       // Shader params buffers
      ubos,
       // Specialization Constants
       {},
       // Resizing Logic
       resize_qlinear_node));
+  if (!graph.is_buffer_storage(out) &&
+      graph.packed_dim_of(out) != WHCN::kWidthDim) {
+    viewFn(graph, {out_W_packed, graph.add_none(), out});
+  }
 }
 
 void weight_int8pack_mm(
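For context, the hunk above wraps the q_8w_linear dispatch in a repack step: when the output uses texture storage and mat1 is not width packed, it allocates width-packed twins, inserts a view_copy node before the dispatch, and inserts another view_copy after it to write the result back into out. A minimal sketch of that pattern, folded into a hypothetical helper (the name ensure_width_packed is invented for illustration; the ComputeGraph calls it uses all appear verbatim in the diff):

// Sketch only, not part of the patch: repack a single tensor to width-packed
// layout by inserting a view_copy node, mirroring the inline logic above.
ValueRef ensure_width_packed(ComputeGraph& graph, const ValueRef t) {
  // Buffer-backed tensors and textures that are already width packed can be
  // consumed directly by the q_8w_linear shader.
  if (graph.is_buffer_storage(t) ||
      graph.packed_dim_of(t) == WHCN::kWidthDim) {
    return t;
  }
  // Otherwise allocate a width-packed twin and add a view_copy node that
  // repacks the original tensor into it ahead of the linear dispatch.
  auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
  const ValueRef t_W_packed = graph.add_tensor_like(t, utils::kWidthPacked);
  viewFn(graph, {t, graph.add_none(), t_W_packed});
  return t_W_packed;
}

The output side runs the same idea in reverse: the dispatch writes into out_W_packed, and the trailing view_copy added after resize_qlinear_node copies that result back into the original, non-width-packed out tensor.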