Skip to content

Commit 5e13dcf

Browse files
committed
move ggml_can_fuse to a common function
1 parent b84cb4a commit 5e13dcf

File tree

2 files changed

+42
-29
lines changed

2 files changed

+42
-29
lines changed

ggml/src/ggml-impl.h

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
#include <stdint.h>
1313
#include <string.h>
1414

15+
#ifdef __cplusplus
16+
#include <initializer_list>
17+
#endif
18+
1519
#ifdef __ARM_FEATURE_SVE
1620
#include <arm_sve.h>
1721
#endif // __ARM_FEATURE_SVE
@@ -467,9 +471,10 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
467471
#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
468472
#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
469473

474+
#ifdef __cplusplus
470475
// return true if the node's results are only used by N other nodes
471476
// and can be fused into their calculations.
472-
static inline bool ggml_can_fuse_node(const struct ggml_tensor * node, int32_t N) {
477+
static inline bool ggml_node_has_N_uses(const struct ggml_tensor * node, int32_t N) {
473478
// check the use count against how many we're replacing
474479
if (node->use_count != N) {
475480
return false;
@@ -489,6 +494,40 @@ static inline bool ggml_can_fuse_node(const struct ggml_tensor * node, int32_t N
489494
return true;
490495
}
491496

497+
// Returns true if nodes [i, i+ops.size()) are the sequence of ggml_ops in ops[]
498+
// and are fusable. Nodes are considered fusable according to this function if:
499+
// - all nodes except the last have only one use and are not views/outputs (see ggml_node_has_N_uses).
500+
// - all nodes except the last are src[0] of the following node.
501+
// - all nodes are the same shape.
502+
// TODO: Consider allowing GGML_OP_NONE nodes in between
503+
static bool ggml_can_fuse(struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
504+
size_t num_ops = ops.size();
505+
if (node_idx + num_ops > cgraph->n_nodes) {
506+
return false;
507+
}
508+
509+
for (size_t i = 0; i < num_ops; ++i) {
510+
struct ggml_tensor *node = cgraph->nodes[node_idx + i];
511+
if (node->op != ops.begin()[i]) {
512+
return false;
513+
}
514+
if (i < num_ops && !ggml_node_has_N_uses(node, 1)) {
515+
return false;
516+
}
517+
if (i > 0) {
518+
struct ggml_tensor *prev = cgraph->nodes[node_idx + i - 1];
519+
if (node->src[0] != prev) {
520+
return false;
521+
}
522+
if (!ggml_are_same_shape(node, prev)) {
523+
return false;
524+
}
525+
}
526+
}
527+
return true;
528+
}
529+
#endif
530+
492531
#ifdef __cplusplus
493532
}
494533
#endif

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9697,32 +9697,6 @@ static bool ggml_vk_is_empty(ggml_tensor * node) {
96979697
return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
96989698
}
96999699

9700-
// Returns true if nodes [i, i+1] are fusable RMS_NORM + MUL.
9701-
static bool ggml_can_fuse_rms_norm_mul(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int i) {
9702-
ggml_tensor *norm = cgraph->nodes[i];
9703-
9704-
if (norm->op != GGML_OP_RMS_NORM) {
9705-
return false;
9706-
}
9707-
9708-
if (!ggml_can_fuse_node(norm, 1)) {
9709-
return false;
9710-
}
9711-
9712-
if (i + 1 >= cgraph->n_nodes) {
9713-
return false;
9714-
}
9715-
ggml_tensor *mul = cgraph->nodes[i + 1];
9716-
if (mul->op != GGML_OP_MUL || mul->src[0] != norm) {
9717-
return false;
9718-
}
9719-
9720-
// Since norm is the first operand of mul, it must be the same shape
9721-
GGML_ASSERT(ggml_are_same_shape(mul, norm));
9722-
9723-
return true;
9724-
}
9725-
97269700
static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
97279701
VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
97289702
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
@@ -9736,7 +9710,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
97369710

97379711
uint64_t total_mat_mul_bytes = 0;
97389712
for (int i = 0; i < cgraph->n_nodes; i++) {
9739-
if (ggml_can_fuse_rms_norm_mul(ctx, cgraph, i)) {
9713+
if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
97409714
ctx->num_additional_fused_ops = 1;
97419715
}
97429716
ggml_vk_build_graph(ctx, cgraph, cgraph->nodes[i], i, nullptr, 0, true, false, false, false);
@@ -9806,7 +9780,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
98069780
mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
98079781
}
98089782

9809-
if (ggml_can_fuse_rms_norm_mul(ctx, cgraph, i)) {
9783+
if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
98109784
ctx->num_additional_fused_ops = 1;
98119785
}
98129786

0 commit comments

Comments
 (0)