
Commit ac41e02

Update on "[ET-VK] Migrate ops to use DynamicDispatchNode"
## Changes

* Migrate the operators used in the llama model to use `DynamicDispatchNode` instead of `DispatchNode`.

## Motivation

`DynamicDispatchNode` is a subclass of `DispatchNode` that allows dynamic selection of compute shaders and of global and local work group sizes whenever the command buffer is encoded. This is critical for achieving optimal performance when input shapes are dynamic, since it lets operators select the best compute shader for the current input conditions and adjust global work group sizing to launch the minimum number of work groups necessary.

Without this change, performance of llama 3.2 1B with dynamic shapes enabled is terrible (< 1 tok/s), because global work group sizes are derived from maximum tensor sizes, which in turn are derived from the maximum sequence length. In practice, the sequence length dimension of tensors (even during the prefill phase) will not approach the maximum, so each compute shader dispatch launches a large number of inactive threads.

Differential Revision: [D75878398](https://our.internmc.facebook.com/intern/diff/D75878398/)

[ghstack-poisoned]
2 parents: a334e50 + c81cc04
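To make the work group sizing argument concrete, the sketch below contrasts the two sizing strategies described in the motivation. It is a standalone illustration, not ExecuTorch code; `div_up`, `local_size`, and the sequence lengths are assumed values chosen for the example.

```cpp
#include <cstdint>
#include <iostream>

// Ceiling division: number of work groups needed to cover `extent`
// threads when each group has `local_size` threads.
uint32_t div_up(uint32_t extent, uint32_t local_size) {
  return (extent + local_size - 1) / local_size;
}

int main() {
  const uint32_t local_size = 64;

  // Assumed numbers: a maximum sequence length of 2048 vs. an actual
  // prefill sequence length of 128 in the tensor's current sizes.
  const uint32_t max_seq_len = 2048;
  const uint32_t cur_seq_len = 128;

  // Static sizing (the DispatchNode behavior described in the commit
  // message): group counts are derived once from maximum tensor sizes,
  // so most launched threads are inactive when cur_seq_len is small.
  std::cout << "static groups:  " << div_up(max_seq_len, local_size) << "\n";

  // Dynamic sizing (DynamicDispatchNode): recompute from the current
  // sizes each time the command buffer is encoded, launching only the
  // minimum number of work groups necessary.
  std::cout << "dynamic groups: " << div_up(cur_seq_len, local_size) << "\n";
  return 0;
}
```

With these assumed numbers, static sizing launches 32 work groups where dynamic sizing launches 2; that gap is the source of the < 1 tok/s behavior the commit message describes.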

2 files changed: +0, -4 lines


backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml

Lines changed: 0 additions & 2 deletions
```diff
@@ -46,5 +46,3 @@ unary_op:
       OPERATOR: leaky_relu(X, A)
     - NAME: round
       OPERATOR: round(X)
-    - NAME: tan
-      OPERATOR: tan(X)
```

backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp

Lines changed: 0 additions & 2 deletions
```diff
@@ -154,7 +154,6 @@ DEFINE_ACTIVATION_FN(hardswish);
 DEFINE_ACTIVATION_FN(hardsigmoid);
 DEFINE_LEAKY_RELU_FN(leaky_relu);
 DEFINE_ACTIVATION_FN(round);
-DEFINE_ACTIVATION_FN(tan);
 
 REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.abs.default, abs);
@@ -175,7 +174,6 @@ REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.hardsigmoid.default, hardsigmoid);
   VK_REGISTER_OP(aten.leaky_relu.default, leaky_relu);
   VK_REGISTER_OP(aten.round.default, round);
-  VK_REGISTER_OP(aten.tan.default, tan);
 }
 
 } // namespace vkcompute
```
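For orientation, the pattern these registrations feed into might look like the sketch below. This is a hedged approximation, not the actual `DynamicDispatchNode` interface from the ExecuTorch Vulkan runtime: the struct, the callback signatures, and the `encode` flow are assumptions modeled on the commit message's description (shader and work group sizes re-selected whenever the command buffer is encoded).

```cpp
#include <array>
#include <cstdint>
#include <functional>
#include <string>

// Hypothetical stand-ins for graph/tensor types; the real types live in
// the ExecuTorch Vulkan runtime and differ from this sketch.
struct Tensor { std::array<uint32_t, 4> sizes; };
using GlobalWgSize = std::array<uint32_t, 3>;

// A DispatchNode-like node fixes its shader and work group counts when
// the node is built. A DynamicDispatchNode-like node instead stores
// callbacks and invokes them at encode time.
struct DynamicDispatchSketch {
  std::function<std::string(const Tensor&)> pick_shader;
  std::function<GlobalWgSize(const Tensor&)> pick_global_wg;

  void encode(const Tensor& in) const {
    // Re-resolved on every command buffer encode, so dynamic shapes get
    // a shader and launch size matched to the current input.
    std::string shader = pick_shader(in);
    GlobalWgSize global = pick_global_wg(in);
    // ... record the dispatch using `shader` and `global` ...
    (void)shader; (void)global;
  }
};

int main() {
  DynamicDispatchSketch node{
      [](const Tensor& t) {
        // E.g. pick a shader variant specialized for short sequences.
        return t.sizes[1] <= 128 ? "unary_op_small" : "unary_op";
      },
      [](const Tensor& t) -> GlobalWgSize {
        // Cover only the tensor's current extents, not its maximums.
        return {t.sizes[0], t.sizes[1], 1};
      }};
  node.encode(Tensor{{1, 64, 1, 1}});
  return 0;
}
```

The design point, per the commit message, is that the decision is deferred from node construction to command buffer encoding, which is what makes minimal launches possible under dynamic shapes.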
