Skip to content

Commit 4c9388f

Browse files
authored
metal : add POOL2D and fix IM2COL (#9943)
* add pool_2d Signed-off-by: Junhee Yoo <[email protected]> * fix im2col and add unittest for N>=1024 Signed-off-by: Junhee Yoo <[email protected]> * add tests for N % 1024 != 0 Signed-off-by: Junhee Yoo <[email protected]> * remove trailing whitespaces Signed-off-by: Junhee Yoo <[email protected]> * apply suggestions Signed-off-by: Junhee Yoo <[email protected]> * apply more optimization - original IM2COL kernel + _ext with MIN() Signed-off-by: Junhee Yoo <[email protected]> * apply review: change kernel name of pool_2d Signed-off-by: Junhee Yoo <[email protected]> * apply review Signed-off-by: Junhee Yoo <[email protected]> * fix more formatting and enhance readability Signed-off-by: Junhee Yoo <[email protected]> --------- Signed-off-by: Junhee Yoo <[email protected]>
1 parent 873279b commit 4c9388f

File tree

3 files changed

+299
-19
lines changed

3 files changed

+299
-19
lines changed

ggml/src/ggml-metal.m

Lines changed: 111 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
241241
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
242242
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
243243
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
244+
GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
245+
GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
244246
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
245247
GGML_METAL_KERNEL_TYPE_PAD_F32,
246248
GGML_METAL_KERNEL_TYPE_ARANGE_F32,
@@ -272,6 +274,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
272274
GGML_METAL_KERNEL_TYPE_SIN,
273275
GGML_METAL_KERNEL_TYPE_COS,
274276
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
277+
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
278+
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
275279

276280
GGML_METAL_KERNEL_TYPE_COUNT
277281
};
@@ -685,6 +689,8 @@ @implementation GGMLMetalClass
685689
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
686690
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
687691
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
692+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
693+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true);
688694
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
689695
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
690696
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
@@ -716,6 +722,8 @@ @implementation GGMLMetalClass
716722
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
717723
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
718724
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
725+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
726+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
719727
}
720728

721729
[metal_library release];
@@ -844,8 +852,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
844852
case GGML_OP_IM2COL:
845853
return op->src[0]->type == GGML_TYPE_F16;
846854
case GGML_OP_POOL_1D:
847-
case GGML_OP_POOL_2D:
848855
return false;
856+
case GGML_OP_POOL_2D:
849857
case GGML_OP_UPSCALE:
850858
case GGML_OP_PAD:
851859
case GGML_OP_ARANGE:
@@ -2545,6 +2553,8 @@ static void ggml_metal_encode_node(
25452553
} break;
25462554
case GGML_OP_IM2COL:
25472555
{
2556+
GGML_ASSERT(ggml_is_contiguous(src0));
2557+
GGML_ASSERT(ggml_is_contiguous(src1));
25482558
GGML_ASSERT(src0->type == GGML_TYPE_F16);
25492559
GGML_ASSERT(src1->type == GGML_TYPE_F32);
25502560
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
@@ -2574,30 +2584,54 @@ static void ggml_metal_encode_node(
25742584
const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
25752585
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
25762586

2577-
id<MTLComputePipelineState> pipeline = nil;
2587+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;
2588+
2589+
const bool is_gt_mttpt = ((size_t)(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup;
25782590

25792591
switch (dst->type) {
2580-
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
2581-
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
2592+
case GGML_TYPE_F32: {
2593+
pipeline = (is_gt_mttpt ?
2594+
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32].pipeline
2595+
:
2596+
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline);
2597+
} break;
2598+
case GGML_TYPE_F16: {
2599+
pipeline = (is_gt_mttpt ?
2600+
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16].pipeline
2601+
:
2602+
ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline);
2603+
} break;
25822604
default: GGML_ABORT("fatal error");
25832605
};
25842606

25852607
[encoder setComputePipelineState:pipeline];
2586-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
2587-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2588-
[encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
2589-
[encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
2590-
[encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
2591-
[encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
2592-
[encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
2593-
[encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
2594-
[encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
2595-
[encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
2596-
[encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
2597-
[encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
2598-
[encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
2599-
2600-
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
2608+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
2609+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2610+
[encoder setBytes:&ofs0 length:sizeof(int32_t) atIndex:2];
2611+
[encoder setBytes:&ofs1 length:sizeof(int32_t) atIndex:3];
2612+
[encoder setBytes:&IW length:sizeof(int32_t) atIndex:4];
2613+
[encoder setBytes:&IH length:sizeof(int32_t) atIndex:5];
2614+
[encoder setBytes:&CHW length:sizeof(int32_t) atIndex:6];
2615+
[encoder setBytes:&s0 length:sizeof(int32_t) atIndex:7];
2616+
[encoder setBytes:&s1 length:sizeof(int32_t) atIndex:8];
2617+
[encoder setBytes:&p0 length:sizeof(int32_t) atIndex:9];
2618+
[encoder setBytes:&p1 length:sizeof(int32_t) atIndex:10];
2619+
[encoder setBytes:&d0 length:sizeof(int32_t) atIndex:11];
2620+
[encoder setBytes:&d1 length:sizeof(int32_t) atIndex:12];
2621+
2622+
if (is_gt_mttpt) {
2623+
[encoder setBytes:&N length:sizeof(int32_t) atIndex:13];
2624+
[encoder setBytes:&KH length:sizeof(int32_t) atIndex:14];
2625+
[encoder setBytes:&KW length:sizeof(int32_t) atIndex:15];
2626+
2627+
const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N);
2628+
2629+
const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0);
2630+
2631+
[encoder dispatchThreadgroups:MTLSizeMake(quotient * CHW, OH, OW) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
2632+
} else {
2633+
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
2634+
}
26012635
} break;
26022636
case GGML_OP_UPSCALE:
26032637
{
@@ -3001,6 +3035,64 @@ static void ggml_metal_encode_node(
30013035

30023036
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
30033037
} break;
3038+
case GGML_OP_POOL_2D:
3039+
{
3040+
GGML_ASSERT(ggml_is_contiguous(src0));
3041+
GGML_ASSERT(src0t == GGML_TYPE_F32 && src0t == dstt);
3042+
3043+
const int32_t * opts = dst->op_params;
3044+
enum ggml_op_pool op = opts[0];
3045+
3046+
id<MTLComputePipelineState> pipeline = nil;
3047+
switch (src0t) {
3048+
case GGML_TYPE_F32: {
3049+
switch(op) {
3050+
case GGML_OP_POOL_AVG:
3051+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32].pipeline; break;
3052+
case GGML_OP_POOL_MAX:
3053+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32].pipeline; break;
3054+
default: GGML_ASSERT(false && "not implemented");
3055+
}
3056+
} break;
3057+
default: GGML_ASSERT(false && "not implemented");
3058+
}
3059+
3060+
const int32_t k0 = opts[1];
3061+
const int32_t k1 = opts[2];
3062+
const int32_t s0 = opts[3];
3063+
const int32_t s1 = opts[4];
3064+
const int32_t p0 = opts[5];
3065+
const int32_t p1 = opts[6];
3066+
3067+
const int64_t IH = src0->ne[1];
3068+
const int64_t IW = src0->ne[0];
3069+
3070+
const int64_t N = dst->ne[3];
3071+
const int64_t OC = dst->ne[2];
3072+
const int64_t OH = dst->ne[1];
3073+
const int64_t OW = dst->ne[0];
3074+
3075+
const int64_t parallel_elements = N * OC * OH * OW;
3076+
const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
3077+
const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
3078+
3079+
[encoder setComputePipelineState:pipeline];
3080+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3081+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
3082+
[encoder setBytes:&k0 length:sizeof(int32_t) atIndex:2];
3083+
[encoder setBytes:&k1 length:sizeof(int32_t) atIndex:3];
3084+
[encoder setBytes:&s0 length:sizeof(int32_t) atIndex:4];
3085+
[encoder setBytes:&s1 length:sizeof(int32_t) atIndex:5];
3086+
[encoder setBytes:&p0 length:sizeof(int32_t) atIndex:6];
3087+
[encoder setBytes:&p1 length:sizeof(int32_t) atIndex:7];
3088+
[encoder setBytes:&IH length:sizeof(int64_t) atIndex:8];
3089+
[encoder setBytes:&IW length:sizeof(int64_t) atIndex:9];
3090+
[encoder setBytes:&OH length:sizeof(int64_t) atIndex:10];
3091+
[encoder setBytes:&OW length:sizeof(int64_t) atIndex:11];
3092+
[encoder setBytes:&parallel_elements length:sizeof(int64_t) atIndex:12];
3093+
3094+
[encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
3095+
} break;
30043096
default:
30053097
{
30063098
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));

ggml/src/ggml-metal.metal

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1933,6 +1933,85 @@ kernel void kernel_im2col(
19331933
template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
19341934
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
19351935

1936+
typedef void (im2col_ext_t)(
1937+
device const float * x,
1938+
device char * dst,
1939+
constant int32_t & ofs0,
1940+
constant int32_t & ofs1,
1941+
constant int32_t & IW,
1942+
constant int32_t & IH,
1943+
constant int32_t & CHW,
1944+
constant int32_t & s0,
1945+
constant int32_t & s1,
1946+
constant int32_t & p0,
1947+
constant int32_t & p1,
1948+
constant int32_t & d0,
1949+
constant int32_t & d1,
1950+
constant int32_t & N,
1951+
constant int32_t & KH,
1952+
constant int32_t & KW,
1953+
uint3 tgpig[[threadgroup_position_in_grid]],
1954+
uint3 tgpg[[threadgroups_per_grid]],
1955+
uint3 tpitg[[thread_position_in_threadgroup]],
1956+
uint3 ntg[[threads_per_threadgroup]]);
1957+
1958+
template <typename T>
1959+
kernel void kernel_im2col_ext(
1960+
device const float * x,
1961+
device char * dst,
1962+
constant int32_t & ofs0,
1963+
constant int32_t & ofs1,
1964+
constant int32_t & IW,
1965+
constant int32_t & IH,
1966+
constant int32_t & CHW,
1967+
constant int32_t & s0,
1968+
constant int32_t & s1,
1969+
constant int32_t & p0,
1970+
constant int32_t & p1,
1971+
constant int32_t & d0,
1972+
constant int32_t & d1,
1973+
constant int32_t & N,
1974+
constant int32_t & KH,
1975+
constant int32_t & KW,
1976+
uint3 tgpig[[threadgroup_position_in_grid]],
1977+
uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
1978+
uint3 tpitg[[thread_position_in_threadgroup]],
1979+
uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
1980+
const int32_t KHW = KH * KW; // KHW == ntg[1] * ntg[2], KW == ntg[2]
1981+
1982+
const int32_t d = tgpig[0] / CHW;
1983+
const int32_t chw = tgpig[0] % CHW;
1984+
const int32_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
1985+
const int32_t HW = tgpig[0] % KHW;
1986+
1987+
const int32_t tpitg_0 = (d * ntg[0]) + tpitg[0];
1988+
if (tpitg_0 >= N) {
1989+
return;
1990+
}
1991+
1992+
const int32_t tpitg_1 = HW / KW;
1993+
const int32_t tpitg_2 = HW % KW;
1994+
1995+
const int32_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0;
1996+
const int32_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1;
1997+
1998+
const int32_t offset_dst =
1999+
(tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
2000+
(tgpig_0 * KHW + tpitg_1 * KW + tpitg_2);
2001+
2002+
device T * pdst = (device T *) (dst);
2003+
2004+
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
2005+
pdst[offset_dst] = 0.0f;
2006+
} else {
2007+
const int32_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
2008+
pdst[offset_dst] = x[offset_src + iih * IW + iiw];
2009+
}
2010+
}
2011+
2012+
template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
2013+
template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
2014+
19362015
kernel void kernel_upscale_f32(
19372016
device const char * src0,
19382017
device char * dst,
@@ -6372,3 +6451,102 @@ template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t
63726451
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
63736452
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
63746453
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
6454+
6455+
kernel void kernel_pool_2d_max_f32(
6456+
device const float * src0,
6457+
device float * dst,
6458+
constant int32_t & k0,
6459+
constant int32_t & k1,
6460+
constant int32_t & s0,
6461+
constant int32_t & s1,
6462+
constant int32_t & p0,
6463+
constant int32_t & p1,
6464+
constant int64_t & IH,
6465+
constant int64_t & IW,
6466+
constant int64_t & OH,
6467+
constant int64_t & OW,
6468+
constant int64_t & parallel_elements,
6469+
uint gid[[thread_position_in_grid]]) {
6470+
6471+
if (gid >= parallel_elements) {
6472+
return;
6473+
}
6474+
6475+
const int idx = gid;
6476+
const int I_HW = IH * IW;
6477+
const int O_HW = OH * OW;
6478+
const int nc = idx / O_HW;
6479+
const int cur_oh = idx % O_HW / OW;
6480+
const int cur_ow = idx % O_HW % OW;
6481+
6482+
device const float * i_ptr = src0 + nc * I_HW;
6483+
device float * o_ptr = dst + nc * O_HW;
6484+
6485+
const int start_h = cur_oh * s1 - p1;
6486+
const int bh = MAX(0, start_h);
6487+
const int eh = MIN(IH, start_h + k1);
6488+
const int start_w = cur_ow * s0 - p0;
6489+
const int bw = MAX(0, start_w);
6490+
const int ew = MIN(IW, start_w + k0);
6491+
6492+
float res = -INFINITY;
6493+
6494+
for (int i = bh; i < eh; i += 1) {
6495+
for (int j = bw; j < ew; j += 1) {
6496+
res = MAX(res, i_ptr[i * IW + j]);
6497+
}
6498+
}
6499+
6500+
o_ptr[cur_oh * OW + cur_ow] = res;
6501+
}
6502+
6503+
kernel void kernel_pool_2d_avg_f32(
6504+
device const float * src0,
6505+
device float * dst,
6506+
constant int32_t & k0,
6507+
constant int32_t & k1,
6508+
constant int32_t & s0,
6509+
constant int32_t & s1,
6510+
constant int32_t & p0,
6511+
constant int32_t & p1,
6512+
constant int64_t & IH,
6513+
constant int64_t & IW,
6514+
constant int64_t & OH,
6515+
constant int64_t & OW,
6516+
constant int64_t & parallel_elements,
6517+
uint gid[[thread_position_in_grid]]) {
6518+
6519+
if (gid >= parallel_elements) {
6520+
return;
6521+
}
6522+
6523+
const int idx = gid;
6524+
const int I_HW = IH * IW;
6525+
const int O_HW = OH * OW;
6526+
const int nc = idx / O_HW;
6527+
const int cur_oh = idx % O_HW / OW;
6528+
const int cur_ow = idx % O_HW % OW;
6529+
6530+
device const float * i_ptr = src0 + nc * I_HW;
6531+
device float * o_ptr = dst + nc * O_HW;
6532+
6533+
const int start_h = cur_oh * s1 - p1;
6534+
const int bh = MAX(0, start_h);
6535+
const int eh = MIN(IH, start_h + k1);
6536+
const int start_w = cur_ow * s0 - p0;
6537+
const int bw = MAX(0, start_w);
6538+
const int ew = MIN(IW, start_w + k0);
6539+
// const float scale = 1. / ((eh - bh) * (ew - bw));
6540+
const float scale = 1. / (k0 * k1);
6541+
6542+
float res = 0;
6543+
6544+
for (int i = bh; i < eh; i += 1) {
6545+
for (int j = bw; j < ew; j += 1) {
6546+
float cur = i_ptr[i * IW + j];
6547+
res += cur * scale;
6548+
}
6549+
}
6550+
6551+
o_ptr[cur_oh * OW + cur_ow] = res;
6552+
}

tests/test-backend-ops.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3316,6 +3316,16 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
33163316
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
33173317
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
33183318

3319+
// test cases for 2D im2col
3320+
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 32}, {3, 3, 1, 32}, 1, 1, 1, 1, 1, 1, true));
3321+
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 32}, {3, 3, 2, 32}, 1, 1, 1, 1, 1, 1, true));
3322+
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 1024}, {3, 3, 1, 1024}, 1, 1, 1, 1, 1, 1, true));
3323+
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 1024}, {3, 3, 2, 1024}, 1, 1, 1, 1, 1, 1, true));
3324+
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2048}, {3, 3, 1, 2048}, 1, 1, 1, 1, 1, 1, true));
3325+
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2048}, {3, 3, 2, 2048}, 1, 1, 1, 1, 1, 1, true));
3326+
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2560}, {3, 3, 1, 2560}, 1, 1, 1, 1, 1, 1, true));
3327+
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2560}, {3, 3, 2, 2560}, 1, 1, 1, 1, 1, 1, true));
3328+
33193329
// sycl backend will limit task global_range < MAX_INT
33203330
// test cases for 2D im2col with large input W and H (occurs in stable-diffusion)
33213331
// however these cases need to alloc more memory which may fail in some devices (Intel Arc770, etc.)

0 commit comments

Comments
 (0)