@@ -260,9 +260,6 @@ void check_q_4w_linear_args(
260
260
const int group_size_val = graph.extract_scalar <int >(group_size);
261
261
VK_CHECK_COND (K % group_size_val == 0 );
262
262
263
- VK_CHECK_COND (graph.packed_dim_of (mat1) == WHCN::kWidthDim );
264
- VK_CHECK_COND (graph.packed_dim_of (out) == WHCN::kWidthDim );
265
-
266
263
VK_CHECK_COND (graph.has_standard_axis_map (mat1));
267
264
VK_CHECK_COND (graph.has_standard_axis_map (out));
268
265
}
@@ -320,13 +317,32 @@ void add_q_4w_linear_node(
320
317
321
318
const uint32_t group_size_val = graph.extract_scalar <uint32_t >(group_size);
322
319
320
+ ValueRef mat1_W_packed = mat1;
321
+ ValueRef out_W_packed = out;
322
+ auto viewFn = VK_GET_OP_FN (" aten.view_copy.default" );
323
+ // Create temporary tensors to store the width packed versions of mat1 and out
324
+ TmpTensor mat1_tmp (
325
+ &graph, graph.sizes_of (mat1), graph.dtype_of (mat1), utils::kWidthPacked );
326
+ TmpTensor out_tmp (
327
+ &graph, graph.sizes_of (out), graph.dtype_of (out), utils::kWidthPacked );
328
+ if (storage_type == utils::kTexture3D ) {
329
+ if (!graph.is_buffer_storage (out) &&
330
+ graph.packed_dim_of (mat1) != WHCN::kWidthDim ) {
331
+ // Ensure mat1 is width packed
332
+ mat1_W_packed = mat1_tmp;
333
+ viewFn (graph, {mat1, graph.add_none (), mat1_W_packed});
334
+ // Ensure out is packed correctly
335
+ out_W_packed = out_tmp;
336
+ }
337
+ }
338
+
323
339
vkapi::ParamsBindList ubos ({});
324
- ubos.append (graph.logical_limits_ubo (out ));
325
- ubos.append (graph.sizes_ubo (mat1 ));
340
+ ubos.append (graph.logical_limits_ubo (out_W_packed ));
341
+ ubos.append (graph.sizes_ubo (mat1_W_packed ));
326
342
ubos.append (graph.strides_ubo (mat2));
327
343
ubos.append (graph.strides_ubo (scales_and_zeros));
328
344
329
- utils::uvec3 global_wg_size = graph.logical_limits_of (out );
345
+ utils::uvec3 global_wg_size = graph.logical_limits_of (out_W_packed );
330
346
utils::uvec3 local_wg_size = graph.create_local_wg_size (global_wg_size);
331
347
332
348
graph.execute_nodes ().emplace_back (new DispatchNode (
@@ -335,15 +351,19 @@ void add_q_4w_linear_node(
335
351
global_wg_size,
336
352
local_wg_size,
337
353
// Inputs and Outputs
338
- {{out , vkapi::MemoryAccessType::WRITE},
339
- {{mat1 , mat2, scales_and_zeros}, vkapi::MemoryAccessType::READ}},
354
+ {{out_W_packed , vkapi::MemoryAccessType::WRITE},
355
+ {{mat1_W_packed , mat2, scales_and_zeros}, vkapi::MemoryAccessType::READ}},
340
356
// Shader params buffers
341
357
ubos,
342
358
// Specialization Constants
343
359
{SV (group_size_val)},
344
360
// Resizing Logic
345
361
resize_q_4w_linear_node,
346
362
{}));
363
+ if (!graph.is_buffer_storage (out) &&
364
+ graph.packed_dim_of (out) != WHCN::kWidthDim ) {
365
+ viewFn (graph, {out_W_packed, graph.add_none (), out});
366
+ }
347
367
}
348
368
349
369
void linear_weight_int4 (
0 commit comments