@@ -9697,32 +9697,6 @@ static bool ggml_vk_is_empty(ggml_tensor * node) {
9697
9697
return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
9698
9698
}
9699
9699
9700
- // Returns true if nodes [i, i+1] are fusable RMS_NORM + MUL.
9701
- static bool ggml_can_fuse_rms_norm_mul(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int i) {
9702
- ggml_tensor *norm = cgraph->nodes[i];
9703
-
9704
- if (norm->op != GGML_OP_RMS_NORM) {
9705
- return false;
9706
- }
9707
-
9708
- if (!ggml_can_fuse_node(norm, 1)) {
9709
- return false;
9710
- }
9711
-
9712
- if (i + 1 >= cgraph->n_nodes) {
9713
- return false;
9714
- }
9715
- ggml_tensor *mul = cgraph->nodes[i + 1];
9716
- if (mul->op != GGML_OP_MUL || mul->src[0] != norm) {
9717
- return false;
9718
- }
9719
-
9720
- // Since norm is the first operand of mul, it must be the same shape
9721
- GGML_ASSERT(ggml_are_same_shape(mul, norm));
9722
-
9723
- return true;
9724
- }
9725
-
9726
9700
static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
9727
9701
VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
9728
9702
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
@@ -9736,7 +9710,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
9736
9710
9737
9711
uint64_t total_mat_mul_bytes = 0;
9738
9712
for (int i = 0; i < cgraph->n_nodes; i++) {
9739
- if (ggml_can_fuse_rms_norm_mul(ctx, cgraph, i)) {
9713
+ if (ggml_can_fuse( cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL } )) {
9740
9714
ctx->num_additional_fused_ops = 1;
9741
9715
}
9742
9716
ggml_vk_build_graph(ctx, cgraph, cgraph->nodes[i], i, nullptr, 0, true, false, false, false);
@@ -9806,7 +9780,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
9806
9780
mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
9807
9781
}
9808
9782
9809
- if (ggml_can_fuse_rms_norm_mul(ctx, cgraph, i)) {
9783
+ if (ggml_can_fuse( cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL } )) {
9810
9784
ctx->num_additional_fused_ops = 1;
9811
9785
}
9812
9786
0 commit comments