Skip to content

Commit c5d3792

Browse files
committed
extract some common fusion logic
1 parent 1129a5b commit c5d3792

File tree

2 files changed

+26
-6
lines changed

2 files changed

+26
-6
lines changed

ggml/src/ggml-impl.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,28 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
589589
#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
590590
#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
591591

592+
// return true if the node's results are only used by N other nodes
593+
// and can be fused into their calculations.
594+
static inline bool ggml_can_fuse_node(const struct ggml_tensor * node, int32_t N) {
595+
// check the use count against how many we're replacing
596+
if (node->use_count != N) {
597+
return false;
598+
}
599+
600+
// if node is a view, some other node might be using the intermediate result
601+
// via the view source.
602+
if (node->view_src) {
603+
return false;
604+
}
605+
606+
// If the user requested output for the node, can't fuse
607+
if (node->flags & GGML_TENSOR_FLAG_OUTPUT) {
608+
return false;
609+
}
610+
611+
return true;
612+
}
613+
592614
#ifdef __cplusplus
593615
}
594616
#endif

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9698,15 +9698,14 @@ static bool ggml_vk_is_empty(ggml_tensor * node) {
96989698
}
96999699

97009700
// Returns true if nodes [i, i+1] are fusable RMS_NORM + MUL.
9701-
bool ggml_can_fuse_rms_norm_mul(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int i) {
9701+
static bool ggml_can_fuse_rms_norm_mul(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int i) {
97029702
ggml_tensor *norm = cgraph->nodes[i];
97039703

9704-
if (norm->op != GGML_OP_RMS_NORM || norm->use_count != 1) {
9704+
if (norm->op != GGML_OP_RMS_NORM) {
97059705
return false;
97069706
}
9707-
// if norm is a view, some other node might be using the intermediate result
9708-
// view the view source.
9709-
if (norm->view_src) {
9707+
9708+
if (!ggml_can_fuse_node(norm, 1)) {
97109709
return false;
97119710
}
97129711

@@ -9721,7 +9720,6 @@ bool ggml_can_fuse_rms_norm_mul(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
97219720
// Since norm is the first operand of mul, it must be the same shape
97229721
GGML_ASSERT(ggml_are_same_shape(mul, norm));
97239722

9724-
// XXX TODO: Do we need a way to indicate that the user doesn't need the intermediate result?
97259723
return true;
97269724
}
97279725

0 commit comments

Comments
 (0)