@@ -5246,30 +5246,19 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
5246
5246
5247
5247
static __global__ void im2col_f32_f16 (
5248
5248
const float * x, half * dst,
5249
- int offset_delta , int IW , int IH , int OW, int KW, int KH, int pelements , int CHW,
5249
+ int ofs0 , int ofs1 , int IW , int IH , int CHW,
5250
5250
int s0, int s1, int p0, int p1, int d0, int d1) {
5251
- const int i = threadIdx .x + blockIdx .x * blockDim .x ;
5252
- if (i >= pelements) {
5253
- return ;
5254
- }
5255
-
5256
- const int ksize = OW * (KH > 1 ? KW : 1 );
5257
- const int kx = i / ksize;
5258
- const int kd = kx * ksize;
5259
- const int ky = (i - kd) / OW;
5260
- const int ix = i % OW;
5261
-
5262
- const int iiw = ix * s0 + kx * d0 - p0;
5263
- const int iih = blockIdx .y * s1 + ky * d1 - p1;
5251
+ const int iiw = blockIdx .z * s0 + threadIdx .z * d0 - p0;
5252
+ const int iih = blockIdx .y * s1 + threadIdx .y * d1 - p1;
5264
5253
5265
5254
const int offset_dst =
5266
- (blockIdx .y * OW + ix ) * CHW +
5267
- (blockIdx .z * (KW * KH ) + ky * KW + kx );
5255
+ (threadIdx . x * gridDim . y * gridDim . z + blockIdx .y * gridDim . z + blockIdx . z ) * CHW +
5256
+ (blockIdx .x * (blockDim . y * blockDim . z ) + threadIdx . y * blockDim . z + threadIdx . z );
5268
5257
5269
5258
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
5270
5259
dst[offset_dst] = __float2half (0 .0f );
5271
5260
} else {
5272
- const int offset_src = blockIdx .z * offset_delta ;
5261
+ const int offset_src = threadIdx . x * ofs0 + blockIdx .x * ofs1 ;
5273
5262
dst[offset_dst] = __float2half (x[offset_src + iih * IW + iiw]);
5274
5263
}
5275
5264
}
@@ -6502,14 +6491,13 @@ static void soft_max_f32_cuda(const float * x, const float * y, float * dst, con
6502
6491
soft_max_f32<<<block_nums, block_dims, 0 , stream>>> (x, y, dst, ncols_x, nrows_y, scale);
6503
6492
}
6504
6493
6505
- static void im2col_f32_f16_cuda (const float * x, half* dst,
6506
- int IW, int IH, int OW, int OH, int KW, int KH, int IC,
6507
- int offset_delta,
6508
- int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
6509
- const int parallel_elements = OW * KW * KH;
6510
- const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1 ) / CUDA_IM2COL_BLOCK_SIZE;
6511
- dim3 block_nums (num_blocks, OH, IC);
6512
- im2col_f32_f16<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0 , stream>>> (x, dst, offset_delta, IW, IH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
6494
+ static void im2col_f32_f16_cuda (const float * x, half * dst,
6495
+ int OH, int IW, int IH, int OW, int IC,
6496
+ int KH, int KW, int N, int ofs0, int ofs1,
6497
+ int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) {
6498
+ dim3 block_nums (IC, OH, OW);
6499
+ dim3 block_dims (N, KH, KW);
6500
+ im2col_f32_f16<<<block_nums, block_dims, 0 , stream>>> (x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
6513
6501
}
6514
6502
6515
6503
// buffer pool for cuda
0 commit comments