@@ -15,7 +15,7 @@ struct kernel_bounds {
15
15
int x_min, x_max;
16
16
};
17
17
18
- __device__ inline kernel_bounds calculate_kernel_bounds (int out_x, int out_y, const conv_params & params) {
18
+ __device__ __forceinline__ kernel_bounds calculate_kernel_bounds (int out_x, int out_y, const conv_params & params) {
19
19
kernel_bounds bounds;
20
20
bounds.y_min = max (0 , (params.padding_y - out_y * params.stride_y + params.dilation_y - 1 ) / params.dilation_y );
21
21
bounds.y_max =
@@ -28,7 +28,7 @@ __device__ inline kernel_bounds calculate_kernel_bounds(int out_x, int out_y, co
28
28
return bounds;
29
29
}
30
30
31
- __device__ inline int calculate_input_coord (int out_coord, int kern_coord, int stride, int dilation, int padding) {
31
+ __device__ __forceinline__ int calculate_input_coord (int out_coord, int kern_coord, int stride, int dilation, int padding) {
32
32
return out_coord * stride + kern_coord * dilation - padding;
33
33
}
34
34
@@ -84,8 +84,8 @@ __global__ void conv2d_dw_kernel(const T * __restrict__ input, const T * __restr
84
84
const int kernel_w, const int kernel_h, const int stride_x, const int stride_y,
85
85
const int padding_x, const int padding_y, const int dilation_x, const int dilation_y,
86
86
const int channels, const int batches) {
87
- int global_idx = blockIdx .x * blockDim .x + threadIdx .x ;
88
- int total_elements = batches * channels * out_h * out_w;
87
+ const int global_idx = blockIdx .x * blockDim .x + threadIdx .x ;
88
+ const int total_elements = batches * channels * out_h * out_w;
89
89
90
90
if (global_idx >= total_elements) {
91
91
return ;
0 commit comments