Skip to content

Commit c2082d9

Browse files
PABannier and ggerganov
authored and committed
ggml : add GGML_PAD_REFLECT_1D operation (ggml/1034)
* ggml_pad_reflect_1d defined in header
* implemented on CPU
* called the forward pass
* impl Metal kernel
* added Metal kernel
* added OP_PAD_REFLECT_1D in test-backend-ops.cpp
* add test-pad-reflect-1d test case
* test case support multiple backend
1 parent d405804 commit c2082d9

File tree

6 files changed

+192
-2
lines changed

6 files changed

+192
-2
lines changed

ggml/include/ggml.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,7 @@ extern "C" {
499499
GGML_OP_POOL_2D_BACK,
500500
GGML_OP_UPSCALE, // nearest interpolate
501501
GGML_OP_PAD,
502+
GGML_OP_PAD_REFLECT_1D,
502503
GGML_OP_ARANGE,
503504
GGML_OP_TIMESTEP_EMBEDDING,
504505
GGML_OP_ARGSORT,
@@ -1695,6 +1696,13 @@ extern "C" {
16951696
int p2,
16961697
int p3);
16971698

1699+
// pad the first dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c]
1700+
GGML_API struct ggml_tensor * ggml_pad_reflect_1d(
1701+
struct ggml_context * ctx,
1702+
struct ggml_tensor * a,
1703+
int p0,
1704+
int p1);
1705+
16981706
// Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
16991707
// timesteps: [N,]
17001708
// return: [N, dim]

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10439,6 +10439,40 @@ static void ggml_compute_forward_pad(
1043910439
}
1044010440
}
1044110441

10442+
// ggml_compute_forward_pad_reflect_1d
10443+
10444+
static void ggml_compute_forward_pad_reflect_1d(
10445+
const struct ggml_compute_params * params,
10446+
struct ggml_tensor * dst) {
10447+
10448+
const struct ggml_tensor * src0 = dst->src[0];
10449+
10450+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
10451+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
10452+
10453+
const int ith = params->ith;
10454+
const int nth = params->nth;
10455+
10456+
const int32_t * opts = (const int32_t *) dst->op_params;
10457+
const int p0 = opts[0];
10458+
const int p1 = opts[1];
10459+
10460+
GGML_TENSOR_UNARY_OP_LOCALS
10461+
10462+
for (int64_t i3 = 0; i3 < ne3; i3++) {
10463+
for (int64_t i2 = 0; i2 < ne2; i2++) {
10464+
for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
10465+
float * left = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + p0*nb0);
10466+
float * right = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (ne0-p1-1)*nb0);
10467+
10468+
ggml_vec_cpy_f32(ne00, left, (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
10469+
10470+
for (int i0 = 1; i0 <= p0; i0++) { left[-i0] = left[i0]; }
10471+
for (int i0 = 1; i0 <= p1; i0++) { right[i0] = right[-i0]; }
10472+
}
10473+
}
10474+
}
10475+
}
1044210476

1044310477
// ggml_compute_forward_arange
1044410478

@@ -12535,6 +12569,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
1253512569
{
1253612570
ggml_compute_forward_pad(params, tensor);
1253712571
} break;
12572+
case GGML_OP_PAD_REFLECT_1D:
12573+
{
12574+
ggml_compute_forward_pad_reflect_1d(params, tensor);
12575+
} break;
1253812576
case GGML_OP_ARANGE:
1253912577
{
1254012578
ggml_compute_forward_arange(params, tensor);
@@ -12877,6 +12915,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
1287712915
} break;
1287812916
case GGML_OP_UPSCALE:
1287912917
case GGML_OP_PAD:
12918+
case GGML_OP_PAD_REFLECT_1D:
1288012919
case GGML_OP_ARANGE:
1288112920
case GGML_OP_TIMESTEP_EMBEDDING:
1288212921
case GGML_OP_ARGSORT:

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
310310
GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32,
311311
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
312312
GGML_METAL_KERNEL_TYPE_PAD_F32,
313+
GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32,
313314
GGML_METAL_KERNEL_TYPE_ARANGE_F32,
314315
GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
315316
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
@@ -877,6 +878,7 @@ @implementation GGMLMetalClass
877878
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true);
878879
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
879880
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
881+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32, pad_reflect_1d_f32, true);
880882
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
881883
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
882884
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
@@ -1099,6 +1101,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
10991101
case GGML_OP_POOL_2D:
11001102
case GGML_OP_UPSCALE:
11011103
case GGML_OP_PAD:
1104+
case GGML_OP_PAD_REFLECT_1D:
11021105
case GGML_OP_ARANGE:
11031106
case GGML_OP_TIMESTEP_EMBEDDING:
11041107
case GGML_OP_ARGSORT:
@@ -3258,6 +3261,38 @@ static void ggml_metal_encode_node(
32583261

32593262
const int nth = MIN(1024, ne0);
32603263

3264+
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
3265+
} break;
3266+
case GGML_OP_PAD_REFLECT_1D:
3267+
{
3268+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
3269+
3270+
const int32_t p0 = ((const int32_t *)(dst->op_params))[0];
3271+
const int32_t p1 = ((const int32_t *)(dst->op_params))[1];
3272+
3273+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32].pipeline;
3274+
3275+
[encoder setComputePipelineState:pipeline];
3276+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3277+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3278+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
3279+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
3280+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
3281+
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
3282+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:6];
3283+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
3284+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
3285+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
3286+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
3287+
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:11];
3288+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:12];
3289+
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:13];
3290+
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:14];
3291+
[encoder setBytes:&p0 length:sizeof(p0) atIndex:15];
3292+
[encoder setBytes:&p1 length:sizeof(p1) atIndex:16];
3293+
3294+
const int nth = MIN(1024, ne0);
3295+
32613296
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
32623297
} break;
32633298
case GGML_OP_ARANGE:

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2897,6 +2897,53 @@ kernel void kernel_pad_f32(
28972897
}
28982898
}
28992899

2900+
kernel void kernel_pad_reflect_1d_f32(
2901+
device const char * src0,
2902+
device char * dst,
2903+
constant int64_t & ne00,
2904+
constant int64_t & ne01,
2905+
constant int64_t & ne02,
2906+
constant int64_t & ne03,
2907+
constant int64_t & ne0,
2908+
constant uint64_t & nb00,
2909+
constant uint64_t & nb01,
2910+
constant uint64_t & nb02,
2911+
constant uint64_t & nb03,
2912+
constant uint64_t & nb0,
2913+
constant uint64_t & nb1,
2914+
constant uint64_t & nb2,
2915+
constant uint64_t & nb3,
2916+
constant int32_t & p0,
2917+
constant int32_t & p1,
2918+
uint3 tgpig[[threadgroup_position_in_grid]],
2919+
uint3 tgpg[[threadgroups_per_grid]],
2920+
uint3 tpitg[[thread_position_in_threadgroup]],
2921+
uint3 ntg[[threads_per_threadgroup]]) {
2922+
2923+
const int64_t i3 = tgpig.z;
2924+
const int64_t i2 = tgpig.y;
2925+
const int64_t i1 = tgpig.x;
2926+
2927+
const int64_t i03 = i3;
2928+
const int64_t i02 = i2;
2929+
const int64_t i01 = i1;
2930+
2931+
device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
2932+
device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
2933+
2934+
if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
2935+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
2936+
if (i0 < p0) {
2937+
dst_ptr[i0] = src0_ptr[p0 - i0];
2938+
} else if (i0 < ne0 - p1) {
2939+
dst_ptr[i0] = src0_ptr[i0 - p0];
2940+
} else {
2941+
dst_ptr[i0] = src0_ptr[(ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1];
2942+
}
2943+
}
2944+
}
2945+
}
2946+
29002947
kernel void kernel_arange_f32(
29012948
device char * dst,
29022949
constant int64_t & ne0,

ggml/src/ggml.c

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -950,6 +950,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
950950
"POOL_2D_BACK",
951951
"UPSCALE",
952952
"PAD",
953+
"PAD_REFLECT_1D",
953954
"ARANGE",
954955
"TIMESTEP_EMBEDDING",
955956
"ARGSORT",
@@ -983,7 +984,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
983984
"OPT_STEP_ADAMW",
984985
};
985986

986-
static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
987+
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
987988

988989
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
989990
"none",
@@ -1045,6 +1046,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
10451046
"pool_2d_back(x)",
10461047
"upscale(x)",
10471048
"pad(x)",
1049+
"pad_reflect_1d(x)",
10481050
"arange(start, stop, step)",
10491051
"timestep_embedding(timesteps, dim, max_period)",
10501052
"argsort(x)",
@@ -1078,7 +1080,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
10781080
"adamw(x)",
10791081
};
10801082

1081-
static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
1083+
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
10821084

10831085
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
10841086

@@ -4097,6 +4099,37 @@ struct ggml_tensor * ggml_pad(
40974099
return result;
40984100
}
40994101

4102+
// ggml_pad_reflect_1d
4103+
4104+
struct ggml_tensor * ggml_pad_reflect_1d(
4105+
struct ggml_context * ctx,
4106+
struct ggml_tensor * a,
4107+
int p0,
4108+
int p1) {
4109+
GGML_ASSERT(p0 >= 0);
4110+
GGML_ASSERT(p1 >= 0);
4111+
4112+
GGML_ASSERT(p0 < a->ne[0]); // padding length on each side must be less than the
4113+
GGML_ASSERT(p1 < a->ne[0]); // existing length of the dimension being padded
4114+
4115+
GGML_ASSERT(ggml_is_contiguous(a));
4116+
GGML_ASSERT(a->type == GGML_TYPE_F32);
4117+
4118+
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
4119+
a->ne[0] + p0 + p1,
4120+
a->ne[1],
4121+
a->ne[2],
4122+
a->ne[3]);
4123+
4124+
int32_t params[] = { p0, p1 };
4125+
ggml_set_op_params(result, params, sizeof(params));
4126+
4127+
result->op = GGML_OP_PAD_REFLECT_1D;
4128+
result->src[0] = a;
4129+
4130+
return result;
4131+
}
4132+
41004133
// ggml_arange
41014134

41024135
struct ggml_tensor * ggml_arange(

tests/test-backend-ops.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2697,6 +2697,33 @@ struct test_pad : public test_case {
26972697
}
26982698
};
26992699

2700+
// GGML_OP_PAD_REFLECT_1D
2701+
struct test_pad_reflect_1d : public test_case {
2702+
const ggml_type type;
2703+
const std::array<int64_t, 4> ne_a;
2704+
const int pad_0;
2705+
const int pad_1;
2706+
2707+
std::string vars() override {
2708+
return VARS_TO_STR4(type, ne_a, pad_0, pad_1);
2709+
}
2710+
2711+
test_pad_reflect_1d(ggml_type type = GGML_TYPE_F32,
2712+
std::array<int64_t, 4> ne_a = {512, 34, 2, 1},
2713+
int pad_0 = 10, int pad_1 = 9)
2714+
: type(type), ne_a(ne_a), pad_0(pad_0), pad_1(pad_1) {}
2715+
2716+
ggml_tensor * build_graph(ggml_context * ctx) override {
2717+
ggml_tensor * a = ggml_new_tensor(ctx, type, 2, ne_a.data());
2718+
ggml_set_name(a, "a");
2719+
2720+
ggml_tensor * out = ggml_pad_reflect_1d(ctx, a, pad_0, pad_1);
2721+
ggml_set_name(out, "out");
2722+
2723+
return out;
2724+
}
2725+
};
2726+
27002727
// GGML_OP_ARANGE
27012728
struct test_arange : public test_case {
27022729
const ggml_type type;
@@ -3816,6 +3843,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
38163843
test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));
38173844
test_cases.emplace_back(new test_acc());
38183845
test_cases.emplace_back(new test_pad());
3846+
test_cases.emplace_back(new test_pad_reflect_1d());
38193847
test_cases.emplace_back(new test_arange());
38203848
test_cases.emplace_back(new test_timestep_embedding());
38213849
test_cases.emplace_back(new test_leaky_relu());

0 commit comments

Comments (0)