Skip to content

Commit 2d5dd7b

Browse files
authored
ggml : add epsilon as a parameter for group_norm (#8818)
Signed-off-by: Molly Sophia <[email protected]>
1 parent cdd1889 commit 2d5dd7b

File tree

7 files changed

+38
-24
lines changed

7 files changed

+38
-24
lines changed

ggml/include/ggml.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,16 +1140,17 @@ extern "C" {
11401140

11411141
// group normalize along ne0*ne1*n_groups
11421142
// used in stable-diffusion
1143-
// TODO: eps is hardcoded to 1e-6 for now
11441143
GGML_API struct ggml_tensor * ggml_group_norm(
11451144
struct ggml_context * ctx,
11461145
struct ggml_tensor * a,
1147-
int n_groups);
1146+
int n_groups,
1147+
float eps);
11481148

11491149
GGML_API struct ggml_tensor * ggml_group_norm_inplace(
11501150
struct ggml_context * ctx,
11511151
struct ggml_tensor * a,
1152-
int n_groups);
1152+
int n_groups,
1153+
float eps);
11531154

11541155
// a - x
11551156
// b - dy

ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -464,9 +464,11 @@ void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
464464
aclTensor* acl_src = ggml_cann_create_tensor(src);
465465
aclTensor* acl_dst = ggml_cann_create_tensor(dst);
466466

467-
const float eps = 1e-6f; // TODO: make this a parameter
468467
int n_groups = dst->op_params[0];
469468

469+
float eps;
470+
memcpy(&eps, dst->op_params + 1, sizeof(float));
471+
470472
uint64_t workspaceSize = 0;
471473
aclOpExecutor* executor;
472474
void* workspaceAddr = nullptr;

ggml/src/ggml-cuda/norm.cu

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,7 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i
142142
}
143143
}
144144

145-
static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const int group_size, const int ne_elements, cudaStream_t stream) {
146-
static const float eps = 1e-6f;
145+
static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {
147146
if (group_size < 1024) {
148147
const dim3 block_dims(WARP_SIZE, 1, 1);
149148
group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
@@ -196,8 +195,12 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
196195
GGML_ASSERT( dst->type == GGML_TYPE_F32);
197196

198197
int num_groups = dst->op_params[0];
198+
199+
float eps;
200+
memcpy(&eps, dst->op_params + 1, sizeof(float));
201+
199202
int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
200-
group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], group_size, ggml_nelements(src0), stream);
203+
group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], eps, group_size, ggml_nelements(src0), stream);
201204
}
202205

203206
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

ggml/src/ggml-metal.m

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2229,10 +2229,8 @@ static enum ggml_status ggml_metal_graph_compute(
22292229
GGML_ASSERT(ne00 % 4 == 0);
22302230
GGML_ASSERT(ggml_is_contiguous(src0));
22312231

2232-
//float eps;
2233-
//memcpy(&eps, dst->op_params, sizeof(float));
2234-
2235-
const float eps = 1e-6f; // TODO: temporarily hardcoded
2232+
float eps;
2233+
memcpy(&eps, dst->op_params + 1, sizeof(float));
22362234

22372235
const int32_t n_groups = ((int32_t *) dst->op_params)[0];
22382236

ggml/src/ggml-sycl/norm.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,9 +225,8 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
225225
}
226226

227227
static void group_norm_f32_sycl(const float* x, float* dst,
228-
const int num_groups, const int group_size,
228+
const int num_groups, const float eps, const int group_size,
229229
const int ne_elements, queue_ptr stream, int device) {
230-
static const float eps = 1e-6f;
231230
if (group_size < 1024) {
232231
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
233232
stream->submit([&](sycl::handler& cgh) {
@@ -343,8 +342,12 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor*
343342
GGML_ASSERT(dst->type == GGML_TYPE_F32);
344343

345344
int num_groups = dst->op_params[0];
345+
346+
float eps;
347+
memcpy(&eps, dst->op_params + 1, sizeof(float));
348+
346349
int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
347-
group_norm_f32_sycl(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device);
350+
group_norm_f32_sycl(src0_dd, dst_dd, num_groups, eps, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device);
348351

349352
(void)src1;
350353
(void)dst;

ggml/src/ggml.c

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5374,6 +5374,7 @@ static struct ggml_tensor * ggml_group_norm_impl(
53745374
struct ggml_context * ctx,
53755375
struct ggml_tensor * a,
53765376
int n_groups,
5377+
float eps,
53775378
bool inplace) {
53785379

53795380
bool is_node = false;
@@ -5384,7 +5385,8 @@ static struct ggml_tensor * ggml_group_norm_impl(
53845385

53855386
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
53865387

5387-
result->op_params[0] = n_groups;
5388+
ggml_set_op_params_i32(result, 0, n_groups);
5389+
ggml_set_op_params_f32(result, 1, eps);
53885390

53895391
result->op = GGML_OP_GROUP_NORM;
53905392
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5396,15 +5398,17 @@ static struct ggml_tensor * ggml_group_norm_impl(
53965398
struct ggml_tensor * ggml_group_norm(
53975399
struct ggml_context * ctx,
53985400
struct ggml_tensor * a,
5399-
int n_groups) {
5400-
return ggml_group_norm_impl(ctx, a, n_groups, false);
5401+
int n_groups,
5402+
float eps) {
5403+
return ggml_group_norm_impl(ctx, a, n_groups, eps, false);
54015404
}
54025405

54035406
struct ggml_tensor * ggml_group_norm_inplace(
54045407
struct ggml_context * ctx,
54055408
struct ggml_tensor * a,
5406-
int n_groups) {
5407-
return ggml_group_norm_impl(ctx, a, n_groups, true);
5409+
int n_groups,
5410+
float eps) {
5411+
return ggml_group_norm_impl(ctx, a, n_groups, eps, true);
54085412
}
54095413

54105414
// ggml_mul_mat
@@ -12095,10 +12099,11 @@ static void ggml_compute_forward_group_norm_f32(
1209512099

1209612100
GGML_TENSOR_UNARY_OP_LOCALS
1209712101

12098-
const float eps = 1e-6f; // TODO: make this a parameter
12099-
1210012102
// TODO: optimize
1210112103

12104+
float eps;
12105+
memcpy(&eps, dst->op_params + 1, sizeof(float));
12106+
1210212107
int n_channels = src0->ne[2];
1210312108
int n_groups = dst->op_params[0];
1210412109
int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;

tests/test-backend-ops.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1511,19 +1511,21 @@ struct test_group_norm : public test_case {
15111511
const ggml_type type;
15121512
const std::array<int64_t, 4> ne;
15131513
const int32_t num_groups;
1514+
const float eps;
15141515

15151516
std::string vars() override {
15161517
return VARS_TO_STR3(type, ne, num_groups);
15171518
}
15181519

15191520
test_group_norm(ggml_type type = GGML_TYPE_F32,
15201521
std::array<int64_t, 4> ne = {64, 64, 320, 1},
1521-
int32_t num_groups = 32)
1522-
: type(type), ne(ne), num_groups(num_groups) {}
1522+
int32_t num_groups = 32,
1523+
float eps = 1e-6f)
1524+
: type(type), ne(ne), num_groups(num_groups), eps(eps) {}
15231525

15241526
ggml_tensor * build_graph(ggml_context * ctx) override {
15251527
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
1526-
ggml_tensor * out = ggml_group_norm(ctx, a, num_groups);
1528+
ggml_tensor * out = ggml_group_norm(ctx, a, num_groups, eps);
15271529
return out;
15281530
}
15291531
};

0 commit comments

Comments
 (0)