ggml: Add epsilon as a parameter for group_norm #8818


Merged
merged 1 commit on Aug 6, 2024
7 changes: 4 additions & 3 deletions ggml/include/ggml.h
@@ -1139,16 +1139,17 @@ extern "C" {

     // group normalize along ne0*ne1*n_groups
     // used in stable-diffusion
-    // TODO: eps is hardcoded to 1e-6 for now
     GGML_API struct ggml_tensor * ggml_group_norm(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            int                   n_groups);
+            int                   n_groups,
+            float                 eps);

     GGML_API struct ggml_tensor * ggml_group_norm_inplace(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            int                   n_groups);
+            int                   n_groups,
+            float                 eps);

     // a - x
     // b - dy
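With the new signature, callers pass the epsilon explicitly instead of relying on the previously hardcoded 1e-6. A minimal usage sketch (hypothetical shape; assumes an existing ggml_context * ctx):

    // Group-normalize a [64, 64, 320, 1] activation tensor into 32 groups,
    // passing the epsilon explicitly rather than relying on a hardcoded 1e-6f.
    struct ggml_tensor * x   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64, 64, 320, 1);
    struct ggml_tensor * out = ggml_group_norm(ctx, x, 32, 1e-6f);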
4 changes: 3 additions & 1 deletion ggml/src/ggml-cann/aclnn_ops.cpp
@@ -464,9 +464,11 @@ void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     aclTensor* acl_src = ggml_cann_create_tensor(src);
     aclTensor* acl_dst = ggml_cann_create_tensor(dst);

-    const float eps = 1e-6f; // TODO: make this a parameter
     int n_groups = dst->op_params[0];

+    float eps;
+    memcpy(&eps, dst->op_params + 1, sizeof(float));
+
     uint64_t workspaceSize = 0;
     aclOpExecutor* executor;
     void* workspaceAddr = nullptr;
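Every backend recovers the parameters the same way: op_params[0] holds n_groups as a 32-bit integer and op_params[1] holds the raw bits of eps, so the float is copied out with memcpy instead of being read through a pointer cast. A small sketch of that convention (the helper name is made up for illustration; op_params is the int32_t array stored on the destination tensor):

    #include <stdint.h>
    #include <string.h>

    // Hypothetical helper mirroring what the CANN/CUDA/Metal/SYCL/CPU paths do inline.
    static void unpack_group_norm_params(const int32_t * op_params, int * n_groups, float * eps) {
        *n_groups = op_params[0];                   // written by ggml_set_op_params_i32(result, 0, n_groups)
        memcpy(eps, op_params + 1, sizeof(float));  // written by ggml_set_op_params_f32(result, 1, eps)
    }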
9 changes: 6 additions & 3 deletions ggml/src/ggml-cuda/norm.cu
@@ -142,8 +142,7 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i
 }
 }

-static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const int group_size, const int ne_elements, cudaStream_t stream) {
-    static const float eps = 1e-6f;
+static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {
     if (group_size < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
         group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
@@ -196,8 +195,12 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     GGML_ASSERT( dst->type == GGML_TYPE_F32);

     int num_groups = dst->op_params[0];
+
+    float eps;
+    memcpy(&eps, dst->op_params + 1, sizeof(float));
+
     int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
-    group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], group_size, ggml_nelements(src0), stream);
+    group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], eps, group_size, ggml_nelements(src0), stream);
 }

 void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
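For reference, a worked instance of the group-size arithmetic above, using the default shape from the test case further down (ne = [64, 64, 320, 1], 32 groups):

    // (320 + 32 - 1) / 32 = 10 channels per group (ceiling division),
    // so group_size = 64 * 64 * 10 = 40960 elements are reduced per group,
    // and the kernel is launched with num_groups * ne[3] = 32 group blocks.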
6 changes: 2 additions & 4 deletions ggml/src/ggml-metal.m
@@ -2229,10 +2229,8 @@ static enum ggml_status ggml_metal_graph_compute(
     GGML_ASSERT(ne00 % 4 == 0);
     GGML_ASSERT(ggml_is_contiguous(src0));

-    //float eps;
-    //memcpy(&eps, dst->op_params, sizeof(float));
-
-    const float eps = 1e-6f; // TODO: temporarily hardcoded
+    float eps;
+    memcpy(&eps, dst->op_params + 1, sizeof(float));

     const int32_t n_groups = ((int32_t *) dst->op_params)[0];

9 changes: 6 additions & 3 deletions ggml/src/ggml-sycl/norm.cpp
@@ -225,9 +225,8 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
 }

 static void group_norm_f32_sycl(const float* x, float* dst,
-                                const int num_groups, const int group_size,
+                                const int num_groups, const float eps, const int group_size,
                                 const int ne_elements, queue_ptr stream, int device) {
-    static const float eps = 1e-6f;
     if (group_size < 1024) {
         const sycl::range<3> block_dims(1, 1, WARP_SIZE);
         stream->submit([&](sycl::handler& cgh) {
@@ -343,8 +342,12 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor*
     GGML_ASSERT(dst->type == GGML_TYPE_F32);

     int num_groups = dst->op_params[0];
+
+    float eps;
+    memcpy(&eps, dst->op_params + 1, sizeof(float));
+
     int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
-    group_norm_f32_sycl(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device);
+    group_norm_f32_sycl(src0_dd, dst_dd, num_groups, eps, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device);

     (void)src1;
     (void)dst;
19 changes: 12 additions & 7 deletions ggml/src/ggml.c
@@ -5358,6 +5358,7 @@ static struct ggml_tensor * ggml_group_norm_impl(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
         int n_groups,
+        float eps,
         bool inplace) {

     bool is_node = false;
@@ -5368,7 +5369,8 @@

     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);

-    result->op_params[0] = n_groups;
+    ggml_set_op_params_i32(result, 0, n_groups);
+    ggml_set_op_params_f32(result, 1, eps);

     result->op = GGML_OP_GROUP_NORM;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5380,15 +5382,17 @@
 struct ggml_tensor * ggml_group_norm(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
-        int n_groups) {
-    return ggml_group_norm_impl(ctx, a, n_groups, false);
+        int n_groups,
+        float eps) {
+    return ggml_group_norm_impl(ctx, a, n_groups, eps, false);
Collaborator (inline comment): (Side question) Is ggml_group_norm equivalent to ggml_norm with ggml_reshape to group rows together beforehand and split them back afterwards?

Collaborator (Author): I guess yes?

 }

 struct ggml_tensor * ggml_group_norm_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
-        int n_groups) {
-    return ggml_group_norm_impl(ctx, a, n_groups, true);
+        int n_groups,
+        float eps) {
+    return ggml_group_norm_impl(ctx, a, n_groups, eps, true);
 }

// ggml_mul_mat
@@ -12079,10 +12083,11 @@ static void ggml_compute_forward_group_norm_f32(

     GGML_TENSOR_UNARY_OP_LOCALS

-    const float eps = 1e-6f; // TODO: make this a parameter
-
     // TODO: optimize

+    float eps;
+    memcpy(&eps, dst->op_params + 1, sizeof(float));
+
     int n_channels = src0->ne[2];
     int n_groups = dst->op_params[0];
     int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;
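On the reviewer's side question above: group normalization can indeed be expressed as ggml_norm over a reshaped view, at least when ne[2] is divisible by n_groups and the tensor is contiguous. A rough sketch of that equivalence (an illustration only, not how the operator is implemented; a, n_groups and eps are assumed to be in scope):

    // Each group becomes one row of a 2D view; ggml_norm then normalizes along rows (ne0)
    // with the same eps, and the result is reshaped back to the original 4D layout.
    struct ggml_tensor * rows      = ggml_reshape_2d(ctx, a,
            a->ne[0] * a->ne[1] * (a->ne[2] / n_groups),
            n_groups * a->ne[3]);
    struct ggml_tensor * normed    = ggml_norm(ctx, rows, eps);
    struct ggml_tensor * regrouped = ggml_reshape_4d(ctx, normed,
            a->ne[0], a->ne[1], a->ne[2], a->ne[3]);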
8 changes: 5 additions & 3 deletions tests/test-backend-ops.cpp
@@ -1511,19 +1511,21 @@ struct test_group_norm : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
     const int32_t num_groups;
+    const float eps;

     std::string vars() override {
         return VARS_TO_STR3(type, ne, num_groups);
     }

     test_group_norm(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {64, 64, 320, 1},
-            int32_t num_groups = 32)
-        : type(type), ne(ne), num_groups(num_groups) {}
+            int32_t num_groups = 32,
+            float eps = 1e-6f)
+        : type(type), ne(ne), num_groups(num_groups), eps(eps) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_tensor * out = ggml_group_norm(ctx, a, num_groups);
+        ggml_tensor * out = ggml_group_norm(ctx, a, num_groups, eps);
         return out;
     }
 };
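A case exercising a non-default epsilon could be registered alongside the existing ones, e.g. (hypothetical values; assumes the test_cases vector used elsewhere in test-backend-ops.cpp):

    // Same shape and group count as the default case, but with eps = 1e-5f.
    test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1}, 32, 1e-5f));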