
Commit 6b20811

ggml: Add epsilon as a parameter for group_norm
Signed-off-by: Molly Sophia <[email protected]>
Parent: 0fbbd88

7 files changed (+25, -21 lines)


ggml/include/ggml.h (4 additions, 3 deletions)

@@ -1139,16 +1139,17 @@ extern "C" {
 
     // group normalize along ne0*ne1*n_groups
     // used in stable-diffusion
-    // TODO: eps is hardcoded to 1e-6 for now
     GGML_API struct ggml_tensor * ggml_group_norm(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            int                   n_groups);
+            int                   n_groups,
+            float                 eps);
 
     GGML_API struct ggml_tensor * ggml_group_norm_inplace(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            int                   n_groups);
+            int                   n_groups,
+            float                 eps);
 
     // a - x
     // b - dy
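
With the new signatures, callers pass eps explicitly; 1e-6f reproduces the value that was previously hardcoded. A minimal caller sketch (the helper name is illustrative, not part of this commit):

    #include "ggml.h"

    // Illustrative: group-normalize an activation tensor in 32 groups,
    // passing the epsilon that used to be fixed inside the op.
    static struct ggml_tensor * build_group_norm(struct ggml_context * ctx,
                                                 struct ggml_tensor  * a) {
        return ggml_group_norm(ctx, a, /*n_groups=*/32, /*eps=*/1e-6f);
    }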

ggml/src/ggml-cann/aclnn_ops.cpp (1 addition, 1 deletion)

@@ -464,8 +464,8 @@ void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     aclTensor* acl_src = ggml_cann_create_tensor(src);
     aclTensor* acl_dst = ggml_cann_create_tensor(dst);
 
-    const float eps = 1e-6f; // TODO: make this a parameter
     int n_groups = dst->op_params[0];
+    const float eps = *(float*)(dst->op_params + 1);
 
     uint64_t workspaceSize = 0;
     aclOpExecutor* executor;
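
Every backend recovers the operator parameters the same way: op_params slot 0 holds n_groups as an int32 and slot 1 holds the bit pattern of eps. A self-contained sketch of that layout (the helper names are hypothetical; the memcpy read is an aliasing-safe equivalent of the pointer cast in the diff):

    #include <stdint.h>
    #include <string.h>

    // Mirrors the GGML_OP_GROUP_NORM packing:
    //   op_params[0] = n_groups (int32)
    //   op_params[1] = float bits of eps
    static void pack_group_norm_params(int32_t * op_params, int32_t n_groups, float eps) {
        op_params[0] = n_groups;
        memcpy(op_params + 1, &eps, sizeof(float));
    }

    static float unpack_group_norm_eps(const int32_t * op_params) {
        float eps;
        memcpy(&eps, op_params + 1, sizeof(float)); // avoids strict-aliasing UB
        return eps;
    }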

ggml/src/ggml-cuda/norm.cu (3 additions, 3 deletions)

@@ -142,8 +142,7 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i
     }
 }
 
-static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const int group_size, const int ne_elements, cudaStream_t stream) {
-    static const float eps = 1e-6f;
+static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {
     if (group_size < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
         group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
@@ -196,8 +195,9 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
     int num_groups = dst->op_params[0];
+    float eps = *(float*)(dst->op_params + 1);
     int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
-    group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], group_size, ggml_nelements(src0), stream);
+    group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], eps, group_size, ggml_nelements(src0), stream);
 }
 
 void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
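
For reference, each group of group_size contiguous values is normalized as (x - mean) / sqrt(var + eps), so eps only guards the division when a group's variance is near zero. A scalar sketch of the math the kernel parallelizes (not the CUDA implementation itself):

    #include <math.h>

    // Normalize one group of n contiguous floats in place:
    //   x[i] = (x[i] - mean) / sqrt(var + eps)
    static void group_norm_ref(float * x, int n, float eps) {
        float sum = 0.0f;
        for (int i = 0; i < n; i++) sum += x[i];
        const float mean = sum / n;

        float var = 0.0f;
        for (int i = 0; i < n; i++) var += (x[i] - mean) * (x[i] - mean);
        var /= n;

        const float scale = 1.0f / sqrtf(var + eps);
        for (int i = 0; i < n; i++) x[i] = (x[i] - mean) * scale;
    }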

ggml/src/ggml-metal.m (1 addition, 4 deletions)

@@ -2229,10 +2229,7 @@ static enum ggml_status ggml_metal_graph_compute(
                         GGML_ASSERT(ne00 % 4 == 0);
                         GGML_ASSERT(ggml_is_contiguous(src0));
 
-                        //float eps;
-                        //memcpy(&eps, dst->op_params, sizeof(float));
-
-                        const float eps = 1e-6f; // TODO: temporarily hardcoded
+                        const float eps = *(float*)(dst->op_params + 1);
 
                         const int32_t n_groups = ((int32_t *) dst->op_params)[0];
 
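
Note that the commented-out code removed here would have read eps from the start of op_params, which is where n_groups lives; the committed read takes it from the second slot, matching the packing in ggml_group_norm_impl below.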

ggml/src/ggml-sycl/norm.cpp (3 additions, 3 deletions)

@@ -225,9 +225,8 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
 }
 
 static void group_norm_f32_sycl(const float* x, float* dst,
-                                const int num_groups, const int group_size,
+                                const int num_groups, const float eps, const int group_size,
                                 const int ne_elements, queue_ptr stream, int device) {
-    static const float eps = 1e-6f;
     if (group_size < 1024) {
         const sycl::range<3> block_dims(1, 1, WARP_SIZE);
         stream->submit([&](sycl::handler& cgh) {
@@ -343,8 +342,9 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor*
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
     int num_groups = dst->op_params[0];
+    float eps = *(float*)(dst->op_params + 1);
     int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
-    group_norm_f32_sycl(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device);
+    group_norm_f32_sycl(src0_dd, dst_dd, num_groups, eps, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device);
 
     (void)src1;
     (void)dst;
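
As in the CUDA path, group_size rounds the channel dimension up to a whole number of groups: for the default test shape {64, 64, 320, 1} with num_groups = 32, group_size = 64 * 64 * ((320 + 31) / 32) = 64 * 64 * 10 = 40960 elements per group.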

ggml/src/ggml.c (8 additions, 4 deletions)

@@ -5358,6 +5358,7 @@ static struct ggml_tensor * ggml_group_norm_impl(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         int n_groups,
+        float eps,
         bool inplace) {
 
     bool is_node = false;
@@ -5369,6 +5370,7 @@ static struct ggml_tensor * ggml_group_norm_impl(
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
     result->op_params[0] = n_groups;
+    *(float*)(result->op_params + 1) = eps;
 
     result->op = GGML_OP_GROUP_NORM;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5380,15 +5382,17 @@ static struct ggml_tensor * ggml_group_norm_impl(
 struct ggml_tensor * ggml_group_norm(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
-        int n_groups) {
-    return ggml_group_norm_impl(ctx, a, n_groups, false);
+        int n_groups,
+        float eps) {
+    return ggml_group_norm_impl(ctx, a, n_groups, eps, false);
 }
 
 struct ggml_tensor * ggml_group_norm_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
-        int n_groups) {
-    return ggml_group_norm_impl(ctx, a, n_groups, true);
+        int n_groups,
+        float eps) {
+    return ggml_group_norm_impl(ctx, a, n_groups, eps, true);
 }
 
 // ggml_mul_mat
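
The packing order here is what every backend above reads back: n_groups as an int32 in slot 0, then the float bits of eps in slot 1. Callers that pass 1e-6f get the exact behavior of the old hardcoded constant, so existing graphs (e.g. the stable-diffusion use mentioned in ggml.h) migrate without any numerical change.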

tests/test-backend-ops.cpp (5 additions, 3 deletions)

@@ -1511,19 +1511,21 @@ struct test_group_norm : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
     const int32_t num_groups;
+    const float eps;
 
     std::string vars() override {
        return VARS_TO_STR3(type, ne, num_groups);
     }
 
     test_group_norm(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {64, 64, 320, 1},
-            int32_t num_groups = 32)
-        : type(type), ne(ne), num_groups(num_groups) {}
+            int32_t num_groups = 32,
+            float eps = 1e-6f)
+        : type(type), ne(ne), num_groups(num_groups), eps(eps) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_tensor * out = ggml_group_norm(ctx, a, num_groups);
+        ggml_tensor * out = ggml_group_norm(ctx, a, num_groups, eps);
         return out;
     }
 };
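
Registering a case with a non-default epsilon would follow the usual pattern (sketch only; assumes the emplace_back registration used elsewhere in test-backend-ops.cpp):

    // Default case (eps = 1e-6f) plus an explicit non-default epsilon:
    test_cases.emplace_back(new test_group_norm());
    test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1}, 32, 1e-5f));

Note that vars() still serializes only type, ne, and num_groups, so two cases differing only in eps would print the same variable string.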
