Skip to content

Commit 4eeea39

Browse files
author
Nathanael See
committed
Update on "[ET-VK][int4] Wrap int4 linear calls with view_copy nodes to squeeze/unsqueeze inputs"
This is done automatically for full-precision linear/mm nodes in the graph at torch.export graph tracing time, but is not done for the int4 op. The new pass adds view_copy nodes, as there are subsequent passes which can fuse view_copy nodes if redundant, and convert view_copy nodes to squeeze/unsqueeze nodes. Differential Revision: [D69065866](https://our.internmc.facebook.com/intern/diff/D69065866/) [ghstack-poisoned]
2 parents fe111eb + d15bdb5 commit 4eeea39

File tree

2 files changed

+2
-3
lines changed

2 files changed

+2
-3
lines changed

backends/vulkan/_passes/squeeze_int4_linear_inputs.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,4 +62,3 @@ def _squeezable(shape: List[int]) -> bool:
6262
kwargs,
6363
meta,
6464
)
65-

backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,8 +352,8 @@ void add_q_4w_linear_node(
352352
local_wg_size,
353353
// Inputs and Outputs
354354
{{out_W_packed, vkapi::MemoryAccessType::WRITE},
355-
{{mat1_W_packed, mat2, scales_and_zeros},
356-
vkapi::MemoryAccessType::READ}},
355+
{{mat1_W_packed, mat2, scales_and_zeros},
356+
vkapi::MemoryAccessType::READ}},
357357
// Shader params buffers
358358
ubos,
359359
// Specialization Constants

0 commit comments

Comments
 (0)