Skip to content

Commit bc01448

Browse files
committed
cuda : restore im2col
ggml-ci
1 parent 936af26 commit bc01448

File tree

1 file changed

+13
-25
lines changed

1 file changed

+13
-25
lines changed

ggml-cuda.cu

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5246,30 +5246,19 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
52465246

52475247
static __global__ void im2col_f32_f16(
52485248
const float * x, half * dst,
5249-
int offset_delta, int IW, int IH, int OW, int KW, int KH, int pelements, int CHW,
5249+
int ofs0, int ofs1, int IW, int IH, int CHW,
52505250
int s0, int s1, int p0, int p1, int d0, int d1) {
5251-
const int i = threadIdx.x + blockIdx.x * blockDim.x;
5252-
if (i >= pelements) {
5253-
return;
5254-
}
5255-
5256-
const int ksize = OW * (KH > 1 ? KW : 1);
5257-
const int kx = i / ksize;
5258-
const int kd = kx * ksize;
5259-
const int ky = (i - kd) / OW;
5260-
const int ix = i % OW;
5261-
5262-
const int iiw = ix * s0 + kx * d0 - p0;
5263-
const int iih = blockIdx.y * s1 + ky * d1 - p1;
5251+
const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
5252+
const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
52645253

52655254
const int offset_dst =
5266-
(blockIdx.y * OW + ix) * CHW +
5267-
(blockIdx.z * (KW * KH) + ky * KW + kx);
5255+
(threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
5256+
(blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z);
52685257

52695258
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
52705259
dst[offset_dst] = __float2half(0.0f);
52715260
} else {
5272-
const int offset_src = blockIdx.z * offset_delta;
5261+
const int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1;
52735262
dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
52745263
}
52755264
}
@@ -6502,14 +6491,13 @@ static void soft_max_f32_cuda(const float * x, const float * y, float * dst, con
65026491
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
65036492
}
65046493

6505-
static void im2col_f32_f16_cuda(const float* x, half* dst,
6506-
int IW, int IH, int OW, int OH, int KW, int KH, int IC,
6507-
int offset_delta,
6508-
int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
6509-
const int parallel_elements = OW * KW * KH;
6510-
const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
6511-
dim3 block_nums(num_blocks, OH, IC);
6512-
im2col_f32_f16<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, offset_delta, IW, IH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
6494+
static void im2col_f32_f16_cuda(const float * x, half * dst,
6495+
int OH, int IW, int IH, int OW, int IC,
6496+
int KH, int KW, int N, int ofs0, int ofs1,
6497+
int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) {
6498+
dim3 block_nums(IC, OH, OW);
6499+
dim3 block_dims(N, KH, KW);
6500+
im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
65136501
}
65146502

65156503
// buffer pool for cuda

0 commit comments

Comments
 (0)