Commit 790824c

Review: fix operation ordering in ggml-cuda, use __forceinline__, use more const
1 parent fd61f86 commit 790824c

2 files changed: +8 -8 lines
ggml/src/ggml-cuda/conv2d-dw.cu (4 additions, 4 deletions)

@@ -15,7 +15,7 @@ struct kernel_bounds {
     int x_min, x_max;
 };

-__device__ inline kernel_bounds calculate_kernel_bounds(int out_x, int out_y, const conv_params & params) {
+__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int out_x, int out_y, const conv_params & params) {
     kernel_bounds bounds;
     bounds.y_min = max(0, (params.padding_y - out_y * params.stride_y + params.dilation_y - 1) / params.dilation_y);
     bounds.y_max =
@@ -28,7 +28,7 @@ __device__ inline kernel_bounds calculate_kernel_bounds(int out_x, int out_y, co
     return bounds;
 }

-__device__ inline int calculate_input_coord(int out_coord, int kern_coord, int stride, int dilation, int padding) {
+__device__ __forceinline__ int calculate_input_coord(int out_coord, int kern_coord, int stride, int dilation, int padding) {
     return out_coord * stride + kern_coord * dilation - padding;
 }

@@ -84,8 +84,8 @@ __global__ void conv2d_dw_kernel(const T * __restrict__ input, const T * __restr
                                  const int kernel_w, const int kernel_h, const int stride_x, const int stride_y,
                                  const int padding_x, const int padding_y, const int dilation_x, const int dilation_y,
                                  const int channels, const int batches) {
-    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
-    int total_elements = batches * channels * out_h * out_w;
+    const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    const int total_elements = batches * channels * out_h * out_w;

     if (global_idx >= total_elements) {
         return;
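A note on the helpers touched above: calculate_input_coord maps an output pixel and a kernel tap to an input coordinate (out * stride + k * dilation - padding), and calculate_kernel_bounds clamps the kernel loop so that coordinate never falls outside the input. The host-side sketch below walks through that arithmetic with made-up parameter values; the y_max expression is an assumed reconstruction (the diff truncates that line), so treat it as illustrative rather than a copy of the device code.

// Minimal host-side sketch of the bound/coordinate arithmetic used by
// conv2d_dw_kernel. Parameter values are illustrative; y_max is an assumed
// reconstruction of the truncated line in the diff.
#include <algorithm>
#include <cstdio>

// Mirrors calculate_input_coord() from conv2d-dw.cu.
static inline int input_coord(int out_coord, int kern_coord, int stride, int dilation, int padding) {
    return out_coord * stride + kern_coord * dilation - padding;
}

int main() {
    const int stride = 1, dilation = 2, padding = 2;  // made-up example values
    const int in_h = 8, kernel_h = 3;
    const int out_y = 0;                              // output row being computed

    // First kernel row whose input coordinate is >= 0 (same formula as bounds.y_min).
    const int y_min = std::max(0, (padding - out_y * stride + dilation - 1) / dilation);
    // One past the last kernel row whose input coordinate stays below in_h (assumed bound).
    const int y_max = std::min(kernel_h, (in_h + padding - out_y * stride + dilation - 1) / dilation);

    for (int ky = y_min; ky < y_max; ++ky) {
        printf("kernel row %d reads input row %d\n", ky, input_coord(out_y, ky, stride, dilation, padding));
    }
    return 0;
}
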
ggml/src/ggml-cuda/ggml-cuda.cu (4 additions, 4 deletions)

@@ -2311,6 +2311,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_IM2COL:
             ggml_cuda_op_im2col(ctx, dst);
             break;
+        case GGML_OP_CONV_2D_DW:
+            ggml_cuda_op_conv2d_dw(ctx, dst);
+            break;
         case GGML_OP_CONV_TRANSPOSE_1D:
             ggml_cuda_op_conv_transpose_1d(ctx,dst);
             break;
@@ -2353,9 +2356,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_OPT_STEP_ADAMW:
             ggml_cuda_opt_step_adamw(ctx, dst);
             break;
-        case GGML_OP_CONV_2D_DW:
-            ggml_cuda_op_conv2d_dw(ctx, dst);
-            break;
         default:
             return false;
     }
@@ -3213,6 +3213,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
         }
         case GGML_OP_IM2COL:
+        case GGML_OP_CONV_2D_DW:
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
@@ -3267,7 +3268,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
-        case GGML_OP_CONV_2D_DW:
             return true;
         default:
             return false;
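
The reordering in ggml-cuda.cu keeps GGML_OP_CONV_2D_DW next to the other convolution/im2col cases in both switches rather than at the end. As a reminder of why the op appears in two places, here is a simplified, self-contained sketch (illustrative only; the demo_* names are invented stand-ins for ggml_cuda_compute_forward and ggml_backend_cuda_device_supports_op): an op has to be dispatched in the compute switch and also reported as supported, since the supports query is what lets the scheduler assign the op to this backend in the first place.

// Simplified sketch of the two hooks a backend op needs (not the real ggml
// code; demo_* names are invented for illustration).
#include <cstdio>

enum demo_op { DEMO_OP_IM2COL, DEMO_OP_CONV_2D_DW, DEMO_OP_UNKNOWN };

// 1) compute_forward-style dispatch: run the kernel for a supported op.
static bool demo_compute_forward(demo_op op) {
    switch (op) {
        case DEMO_OP_IM2COL:
            printf("run im2col kernel\n");
            break;
        case DEMO_OP_CONV_2D_DW:   // kept grouped with the other conv/im2col ops
            printf("run depthwise conv2d kernel\n");
            break;
        default:
            return false;          // op not implemented on this backend
    }
    return true;
}

// 2) supports_op-style query: report the op so it gets scheduled onto this backend.
static bool demo_supports_op(demo_op op) {
    switch (op) {
        case DEMO_OP_IM2COL:
        case DEMO_OP_CONV_2D_DW:
            return true;
        default:
            return false;
    }
}

int main() {
    const demo_op op = DEMO_OP_CONV_2D_DW;
    if (demo_supports_op(op)) {
        demo_compute_forward(op);
    }
    return 0;
}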
