Skip to content

Commit 80273a3

Browse files
JohannesGaesslerggerganov
authored andcommitted
CUDA: fix 1D im2col, add tests (ggml/993)
1 parent c19af0a commit 80273a3

File tree

3 files changed

+34
-9
lines changed

3 files changed

+34
-9
lines changed

ggml/src/ggml-cuda.cu

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3141,7 +3141,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
31413141
case GGML_OP_ROPE:
31423142
return ggml_is_contiguous(op->src[0]);
31433143
case GGML_OP_IM2COL:
3144-
return op->src[0]->type == GGML_TYPE_F16;
31453144
case GGML_OP_POOL_2D:
31463145
case GGML_OP_SUM:
31473146
case GGML_OP_SUM_ROWS:

ggml/src/ggml-cuda/im2col.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,9 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
9191
const int64_t OH = is_2D ? dst->ne[2] : 1;
9292
const int64_t OW = dst->ne[1];
9393

94-
const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
95-
const int64_t batch = src1->ne[3];
96-
const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
94+
const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
95+
const int64_t batch = src1->ne[is_2D ? 3 : 2];
96+
const size_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
9797

9898
if(dst->type == GGML_TYPE_F16) {
9999
im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);

tests/test-backend-ops.cpp

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3308,15 +3308,41 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
33083308
}
33093309
}
33103310

3311-
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));
3312-
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
3313-
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));
3314-
// test cases for 1D im2col
3311+
// im2col 1D
33153312
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
33163313
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
33173314
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
3315+
for (int s0 : {1, 3}) {
3316+
for (int p0 : {0, 3}) {
3317+
for (int d0 : {1, 3}) {
3318+
test_cases.emplace_back(new test_im2col(
3319+
GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 2, 2, 1}, {3, 2, 2, 1},
3320+
s0, 0, p0, 0, d0, 0, false));
3321+
}
3322+
}
3323+
}
3324+
3325+
// im2col 2D
3326+
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));
3327+
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
3328+
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));
3329+
for (int s0 : {1, 3}) {
3330+
for (int s1 : {1, 3}) {
3331+
for (int p0 : {0, 3}) {
3332+
for (int p1 : {0, 3}) {
3333+
for (int d0 : {1, 3}) {
3334+
for (int d1 : {1, 3}) {
3335+
test_cases.emplace_back(new test_im2col(
3336+
GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 20, 2, 2}, {3, 3, 2, 2},
3337+
s0, s1, p0, p1, d0, d1, true));
3338+
}
3339+
}
3340+
}
3341+
}
3342+
}
3343+
}
33183344

3319-
// test cases for 2D im2col
3345+
// extra tests for im2col 2D
33203346
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 32}, {3, 3, 1, 32}, 1, 1, 1, 1, 1, 1, true));
33213347
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 32}, {3, 3, 2, 32}, 1, 1, 1, 1, 1, 1, true));
33223348
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 1024}, {3, 3, 1, 1024}, 1, 1, 1, 1, 1, 1, true));

0 commit comments

Comments
 (0)