Skip to content

Commit fe111eb

Browse files
author
Nathanael See
committed
Update on "[ET-VK][int4] Wrap int4 linear calls with view_copy nodes to squeeze/unsqueeze inputs"
This is done automatically for full-precision linear/mm nodes in the graph at torch.export graph tracing time, but is not done for the int4 op. The new pass adds view_copy nodes, as there are subsequent passes which can fuse view_copy nodes if redundant, and convert view_copy nodes to squeeze/unsqueeze nodes. Differential Revision: [D69065866](https://our.internmc.facebook.com/intern/diff/D69065866/) [ghstack-poisoned]
2 parents 436f436 + 9873869 commit fe111eb

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

backends/vulkan/_passes/squeeze_int4_linear_inputs.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def _squeezable(shape: List[int]) -> bool:
3434
output_shape = meta["val"].shape
3535
if not _squeezable(input_shape):
3636
return super().call_operator(op, args, kwargs, meta)
37-
37+
3838
# squeeze input tensor
3939
squeeze_shape = list(input_shape)
4040
while _squeezable(squeeze_shape):
@@ -43,23 +43,23 @@ def _squeezable(shape: List[int]) -> bool:
4343
squeeze_out = super().call_operator(
4444
exir_ops.edge.aten.view_copy.default,
4545
(args[0], squeeze_shape),
46-
kwargs,
46+
kwargs,
4747
meta,
4848
)
4949
# call linear on squeezed output
5050
new_args = (squeeze_out, *args[1:])
5151
linear_out = super().call_operator(
5252
op,
53-
new_args,
54-
kwargs,
53+
new_args,
54+
kwargs,
5555
meta,
5656
)
5757
# unsqueeze output
5858
unsqueeze_shape = list(output_shape)
5959
return super().call_operator(
6060
exir_ops.edge.aten.view_copy.default,
6161
(linear_out, unsqueeze_shape),
62-
kwargs,
62+
kwargs,
6363
meta,
6464
)
6565

backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,8 @@ void add_q_4w_linear_node(
352352
local_wg_size,
353353
// Inputs and Outputs
354354
{{out_W_packed, vkapi::MemoryAccessType::WRITE},
355-
{{mat1_W_packed, mat2, scales_and_zeros}, vkapi::MemoryAccessType::READ}},
355+
{{mat1_W_packed, mat2, scales_and_zeros},
356+
vkapi::MemoryAccessType::READ}},
356357
// Shader params buffers
357358
ubos,
358359
// Specialization Constants

0 commit comments

Comments
 (0)