Skip to content

Commit 1129a5b

Browse files
committed
vulkan: Add fusion support for RMS_NORM+MUL
- Add a use_count to ggml_tensor, so we can detect if an output is used more than once.
- Change the ggml-vulkan rms_norm shader to optionally multiply by another tensor.
- Add detection logic and basic fusion logic in ggml-vulkan.
- Add some testing support for fusion. Rather than computing one node at a time, allow for computing the whole graph and just testing one node's results. Add rms_norm_mul tests and enable a llama test.
1 parent 73e53dc commit 1129a5b

File tree

8 files changed

+196
-42
lines changed

8 files changed

+196
-42
lines changed

ggml/include/ggml-backend.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ extern "C" {
339339
typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
340340

341341
// Compare the output of two backends
342-
GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data);
342+
GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor *test_node);
343343

344344
// Tensor initialization
345345
GGML_API enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);

ggml/include/ggml.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,11 @@ extern "C" {
604604

605605
void * extra; // extra things e.g. for ggml-cuda.cu
606606

607-
char padding[8];
607+
// number of operations that use this tensor as a src
608+
int32_t use_count;
609+
610+
// add padding if needed to make a multiple of GGML_MEM_ALIGN
611+
char padding[4];
608612
};
609613

610614
static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);

ggml/src/ggml-backend.cpp

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -817,8 +817,8 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
817817
}
818818
if (sched->debug > 1) {
819819
ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node);
820-
GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s]:", i, ggml_op_name(node->op), node->name,
821-
fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node));
820+
GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d:", i, ggml_op_name(node->op), node->name,
821+
fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node), node->use_count);
822822
for (int j = 0; j < GGML_MAX_SRC; j++) {
823823
struct ggml_tensor * src = node->src[j];
824824
if (src == NULL) {
@@ -1826,7 +1826,7 @@ void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy) {
18261826
ggml_free(copy.ctx_unallocated);
18271827
}
18281828

1829-
bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data) {
1829+
bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor *test_node) {
18301830
struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph);
18311831
if (copy.buffer == NULL) {
18321832
return false;
@@ -1837,28 +1837,45 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t
18371837

18381838
assert(g1->n_nodes == g2->n_nodes);
18391839

1840-
for (int i = 0; i < g1->n_nodes; i++) {
1841-
struct ggml_tensor * t1 = g1->nodes[i];
1842-
struct ggml_tensor * t2 = g2->nodes[i];
1840+
if (test_node != nullptr) {
1841+
// Compute the whole graph and only test the output for a specific tensor
1842+
ggml_backend_graph_compute(backend1, g1);
1843+
ggml_backend_graph_compute(backend2, g2);
18431844

1844-
assert(t1->op == t2->op && ggml_are_same_layout(t1, t2));
1845+
int test_node_idx = -1;
1846+
for (int i = 0; i < g1->n_nodes; i++) {
1847+
struct ggml_tensor * t1 = g1->nodes[i];
1848+
if (t1 == test_node) {
1849+
test_node_idx = i;
1850+
break;
1851+
}
1852+
}
1853+
GGML_ASSERT(test_node_idx != -1);
18451854

1846-
struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1);
1847-
struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1);
1855+
callback(test_node_idx, g1->nodes[test_node_idx], g2->nodes[test_node_idx], user_data);
1856+
} else {
1857+
for (int i = 0; i < g1->n_nodes; i++) {
1858+
struct ggml_tensor * t1 = g1->nodes[i];
1859+
struct ggml_tensor * t2 = g2->nodes[i];
18481860

1849-
ggml_backend_graph_compute(backend1, &g1v);
1850-
ggml_backend_graph_compute(backend2, &g2v);
1861+
assert(t1->op == t2->op && ggml_are_same_layout(t1, t2));
18511862

1852-
if (ggml_is_view_op(t1->op)) {
1853-
continue;
1854-
}
1863+
struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1);
1864+
struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1);
18551865

1856-
// compare results, calculate rms etc
1857-
if (!callback(i, t1, t2, user_data)) {
1858-
break;
1866+
ggml_backend_graph_compute(backend1, &g1v);
1867+
ggml_backend_graph_compute(backend2, &g2v);
1868+
1869+
if (ggml_is_view_op(t1->op)) {
1870+
continue;
1871+
}
1872+
1873+
// compare results, calculate rms etc
1874+
if (!callback(i, t1, t2, user_data)) {
1875+
break;
1876+
}
18591877
}
18601878
}
1861-
18621879
ggml_backend_graph_copy_free(copy);
18631880

18641881
return true;

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 69 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,7 @@ struct vk_device_struct {
425425
vk_pipeline pipeline_norm_f32;
426426
vk_pipeline pipeline_group_norm_f32;
427427
vk_pipeline pipeline_rms_norm_f32;
428+
vk_pipeline pipeline_rms_norm_mul_f32;
428429
vk_pipeline pipeline_rms_norm_back_f32;
429430
vk_pipeline pipeline_l2_norm_f32;
430431

@@ -978,6 +979,10 @@ struct ggml_backend_vk_context {
978979

979980
vk_command_pool compute_cmd_pool;
980981
vk_command_pool transfer_cmd_pool;
982+
983+
// number of additional consecutive nodes that are being fused with the
984+
// node currently being processed
985+
int num_additional_fused_ops {};
981986
};
982987

983988
static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
@@ -2655,7 +2660,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
26552660

26562661
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
26572662
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2658-
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
2663+
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1);
2664+
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1);
26592665
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
26602666
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
26612667

@@ -6418,7 +6424,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
64186424
return nullptr;
64196425
case GGML_OP_RMS_NORM:
64206426
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6421-
return ctx->device->pipeline_rms_norm_f32;
6427+
return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
64226428
}
64236429
return nullptr;
64246430
case GGML_OP_RMS_NORM_BACK:
@@ -7518,18 +7524,19 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
75187524
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
75197525
}
75207526

7521-
static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7527+
static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
75227528
float * op_params = (float *)dst->op_params;
75237529
const uint32_t src0_type_size = ggml_type_size(src0->type);
7530+
const uint32_t src1_type_size = ggml_type_size(src1->type);
75247531
const uint32_t dst_type_size = ggml_type_size(dst->type);
75257532

7526-
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, {
7533+
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, {
75277534
(uint32_t)ggml_nelements(src0),
7528-
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7529-
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
7535+
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7536+
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
7537+
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
75307538
0,
7531-
op_params[0], 0.0f,
7532-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7539+
op_params[0], 0.0f, 0,
75337540
}, dryrun);
75347541
}
75357542

@@ -8724,7 +8731,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* t
87248731

87258732
// Returns true if node has enqueued work into the queue, false otherwise
87268733
// If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
8727-
static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
8734+
static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
87288735
if (ggml_is_empty(node) || !node->buffer) {
87298736
return false;
87308737
}
@@ -8962,8 +8969,13 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
89628969

89638970
break;
89648971
case GGML_OP_RMS_NORM:
8965-
ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);
8966-
8972+
if (ctx->num_additional_fused_ops > 0) {
8973+
// fused rms_norm + mul
8974+
ggml_tensor *mul = cgraph->nodes[node_idx + 1];
8975+
ggml_vk_rms_norm(ctx, compute_ctx, src0, mul->src[1], mul, dryrun);
8976+
} else {
8977+
ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, dryrun);
8978+
}
89678979
break;
89688980
case GGML_OP_RMS_NORM_BACK:
89698981
ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -9685,6 +9697,34 @@ static bool ggml_vk_is_empty(ggml_tensor * node) {
96859697
return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
96869698
}
96879699

9700+
// Returns true if nodes [i, i+1] are fusable RMS_NORM + MUL.
9701+
bool ggml_can_fuse_rms_norm_mul(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int i) {
9702+
ggml_tensor *norm = cgraph->nodes[i];
9703+
9704+
if (norm->op != GGML_OP_RMS_NORM || norm->use_count != 1) {
9705+
return false;
9706+
}
9707+
// if norm is a view, some other node might be using the intermediate result
9708+
// via the view source.
9709+
if (norm->view_src) {
9710+
return false;
9711+
}
9712+
9713+
if (i + 1 >= cgraph->n_nodes) {
9714+
return false;
9715+
}
9716+
ggml_tensor *mul = cgraph->nodes[i + 1];
9717+
if (mul->op != GGML_OP_MUL || mul->src[0] != norm) {
9718+
return false;
9719+
}
9720+
9721+
// Since norm is the first operand of mul, it must be the same shape
9722+
GGML_ASSERT(ggml_are_same_shape(mul, norm));
9723+
9724+
// XXX TODO: Do we need a way to indicate that the user doesn't need the intermediate result?
9725+
return true;
9726+
}
9727+
96889728
static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
96899729
VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
96909730
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
@@ -9698,10 +9738,15 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
96989738

96999739
uint64_t total_mat_mul_bytes = 0;
97009740
for (int i = 0; i < cgraph->n_nodes; i++) {
9701-
ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false, false);
9741+
if (ggml_can_fuse_rms_norm_mul(ctx, cgraph, i)) {
9742+
ctx->num_additional_fused_ops = 1;
9743+
}
9744+
ggml_vk_build_graph(ctx, cgraph, cgraph->nodes[i], i, nullptr, 0, true, false, false, false);
97029745
if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
97039746
total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
97049747
}
9748+
i += ctx->num_additional_fused_ops;
9749+
ctx->num_additional_fused_ops = 0;
97059750
}
97069751
if (ctx->device->need_compiles) {
97079752
ggml_vk_load_shaders(ctx->device);
@@ -9763,14 +9808,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
97639808
mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
97649809
}
97659810

9811+
if (ggml_can_fuse_rms_norm_mul(ctx, cgraph, i)) {
9812+
ctx->num_additional_fused_ops = 1;
9813+
}
9814+
97669815
// Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
97679816
bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
97689817
bool submit = (submitted_nodes >= nodes_per_submit) ||
97699818
(mul_mat_bytes >= mul_mat_bytes_per_submit) ||
9770-
(i == last_node) ||
9819+
(i + ctx->num_additional_fused_ops == last_node) ||
97719820
(almost_ready && !ctx->almost_ready_fence_pending);
97729821

9773-
bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, almost_ready, submit);
9822+
bool enqueued = ggml_vk_build_graph(ctx, cgraph, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops == last_node, almost_ready, submit);
97749823

97759824
if (vk_perf_logger_enabled) {
97769825
if (ctx->compute_ctx.expired()) {
@@ -9780,7 +9829,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
97809829
} else {
97819830
compute_ctx = ctx->compute_ctx.lock();
97829831
}
9783-
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+1);
9832+
// If there are fused ops, just write out timestamps for all nodes to keep the accounting simple
9833+
for (int j = 0; j < ctx->num_additional_fused_ops + 1; ++j) {
9834+
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+j+1);
9835+
}
97849836
}
97859837

97869838
if (enqueued) {
@@ -9802,6 +9854,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
98029854
}
98039855
submit_count++;
98049856
}
9857+
i += ctx->num_additional_fused_ops;
9858+
ctx->num_additional_fused_ops = 0;
98059859
}
98069860

98079861
if (vk_perf_logger_enabled) {

ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
#version 450
22

3-
#include "generic_unary_head.comp"
3+
#include "generic_binary_head.comp"
44
#include "types.comp"
55

66
#extension GL_EXT_control_flow_attributes : enable
77
#define BLOCK_SIZE 512
88

9+
layout (constant_id = 1) const bool do_multiply = false;
10+
911
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
1012

1113
shared FLOAT_TYPE sum[BLOCK_SIZE];
@@ -25,6 +27,7 @@ void main() {
2527
const uint stride_sample = p.nb03;
2628

2729
uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
30+
uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
2831
uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
2932

3033
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
@@ -46,7 +49,13 @@ void main() {
4649
const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols);
4750
const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
4851

49-
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
50-
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
52+
if (do_multiply) {
53+
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
54+
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
55+
}
56+
} else {
57+
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
58+
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
59+
}
5160
}
5261
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -497,7 +497,7 @@ void process_shaders() {
497497
// Norms
498498
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
499499
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
500-
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
500+
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
501501
string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
502502
string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
503503

ggml/src/ggml.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1619,6 +1619,7 @@ static struct ggml_tensor * ggml_new_tensor_impl(
16191619
/*.data =*/ obj_alloc_size > 0 ? (void *)(result + 1) : data,
16201620
/*.name =*/ { 0 },
16211621
/*.extra =*/ NULL,
1622+
/*.use_count =*/ 0,
16221623
/*.padding =*/ { 0 },
16231624
};
16241625

@@ -5828,6 +5829,7 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
58285829
/* unknown order, just fall back to using i*/ i;
58295830
if (node->src[k]) {
58305831
ggml_visit_parents(cgraph, node->src[k]);
5832+
node->src[k]->use_count++;
58315833
}
58325834
}
58335835

0 commit comments

Comments
 (0)