// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

#include <executorch/backends/cadence/reference/operators/operators.h>

#include <algorithm>
#include <cstring>

namespace impl {
namespace reference {
namespace native {

using ::executorch::aten::IntArrayRef;
using ::executorch::aten::ScalarType;
using ::executorch::aten::Tensor;
using ::executorch::runtime::KernelRuntimeContext;

template <typename T>
__attribute__((always_inline)) void im2row_(
    const T* __restrict__ data_im,
    const int32_t in_zero_point,
    /* input parameters */
    const int32_t channels,
    const int32_t height,
    const int32_t width,
    /* output parameters */
    const int32_t out_height,
    const int32_t out_width,
    /* convolution parameters */
    const int32_t kernel_h,
    const int32_t kernel_w,
    const int32_t pad_h,
    const int32_t pad_w,
    const int32_t stride_h,
    const int32_t stride_w,
    const int32_t dilation_h,
    const int32_t dilation_w,
    T* __restrict__ data_col,
    bool channels_last) {
  // Consider convolving an input image of dimensions channels * height * width
  // (or height * width * channels for NHWC layout) with a filter of dimensions
  // channels * kernel_h * kernel_w. Assume that this convolution produces an
  // output of dimensions out_height x out_width. For each point in the output,
  // im2row takes the data from the input that is used in the computation of
  // that output point and flattens it into a vector of size channels_col =
  // channels * kernel_h * kernel_w. The output of im2row is therefore a 2D
  // array of size (out_height * out_width) x channels_col.
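  // For example, a 3 x 5 x 5 input convolved with a 3 x 3 filter (stride 1,
  // no padding, no dilation) gives out_height = out_width = 3,
  // channels_col = 3 * 3 * 3 = 27, and an im2row output of size 9 x 27.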
  const int32_t channels_col = channels * kernel_h * kernel_w;

  // If the layout is NHWC, we can copy 'channels' worth of contiguous data
  // points when performing im2row.
  if (channels_last) {
    // Iterate over the output domain
    for (int _h = 0; _h < out_height; ++_h) {
      for (int _w = 0; _w < out_width; ++_w) {
        int32_t i_col = _h * out_width + _w;
        // Each point in the output domain is the result of applying a filter
        // of size kernel_h x kernel_w x channels on the input. But since the
        // channel dimension is contiguous, we do not need an explicit loop
        // over it.
        for (int _kh = 0; _kh < kernel_h; ++_kh) {
          int32_t h_im = _h * stride_h - pad_h + _kh * dilation_h;
          for (int _kw = 0; _kw < kernel_w; ++_kw) {
            int32_t w_im = _w * stride_w - pad_w + _kw * dilation_w;

            // h_im and w_im are the actual height and width coordinates of the
            // input tensor from where we need to copy 'channels' points.
            const T* __restrict__ slice_im =
                data_im + (h_im * width + w_im) * channels;
            T* __restrict__ slice_col = data_col + i_col * channels_col +
                (_kh * kernel_w + _kw) * channels;
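            // Each output row i_col is laid out as
            // [kernel_h][kernel_w][channels], so (_kh, _kw) selects the
            // contiguous channel slice to fill.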
            // If the coordinates are within the input domain, copy 'channels'
            // contiguous values. Otherwise, fill the output with the input
            // zero point.
            if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
              std::memcpy(slice_col, slice_im, channels * sizeof(T));
            } else {
              std::fill_n(slice_col, channels, T(in_zero_point));
            }
          }
        }
      }
    }
  } else {
    // Iterate over the output domain
    for (int _h = 0; _h < out_height; ++_h) {
      for (int _w = 0; _w < out_width; ++_w) {
        int32_t i_col = _h * out_width + _w;

        // Each point in the output domain is the result of applying a filter
        // of size channels * kernel_h * kernel_w on the input.
        for (int _c = 0; _c < channels; ++_c) {
          for (int _kh = 0; _kh < kernel_h; ++_kh) {
            for (int _kw = 0; _kw < kernel_w; ++_kw) {
              // c_col is the linearized access in the channels_col vector.
              int32_t c_col = (_c * kernel_h + _kh) * kernel_w + _kw;
              // h_im and w_im are the actual height and width coordinates of
              // the input tensor that we need to copy to the output.
              int32_t h_im = _h * stride_h - pad_h + _kh * dilation_h;
              int32_t w_im = _w * stride_w - pad_w + _kw * dilation_w;
              // If the access is within the input tensor, copy the value;
              // otherwise fill with the input zero point.
              data_col[i_col * channels_col + c_col] =
                  (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width)
                  ? data_im[(_c * height + h_im) * width + w_im]
                  : static_cast<T>(in_zero_point);
            }
          }
        }
      }
    }
  }
}

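// im2row_out flattens each convolution patch of the input into one row of the
// output. out is expected to have shape
// (batch, out_h * out_w, kernel_h * kernel_w * in_c), matching the dimension
// checks below.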
void im2row_out(
    __ET_UNUSED KernelRuntimeContext& ctx,
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef dilation,
    IntArrayRef padding,
    IntArrayRef stride,
    const Tensor& in_zero_point,
    bool channel_last,
    Tensor& out) {
  // Compute the input tensor's dims
  bool unit_height = input.dim() == 3;
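  // A 3-D input is treated as having unit height, i.e. NCW (or NWC when
  // channel_last is set) rather than NCHW (NHWC).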
  const int32_t batch_size = input.size(0);
  const int32_t in_c =
      channel_last ? input.size(3 - unit_height) : input.size(1);
  const int32_t in_h =
      unit_height ? 1 : (channel_last ? input.size(1) : input.size(2));
  const int32_t in_w =
      channel_last ? input.size(2 - unit_height) : input.size(3 - unit_height);

  // Get the kernel parameters
  int32_t kernel_h = kernel_size[0];
  int32_t kernel_w = kernel_size[1];
  int32_t dilation_h = dilation[0];
  int32_t dilation_w = dilation[1];
  int32_t pad_h = padding[0];
  int32_t pad_w = padding[1];
  int32_t stride_h = stride[0];
  int32_t stride_w = stride[1];

  // If we were to apply a convolution on the input tensor, compute the output
  // height and width.
  int32_t out_h =
      (in_h + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) / stride_h + 1;
  int32_t out_w =
      (in_w + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) / stride_w + 1;
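  // For example, in_h = 5 with kernel_h = 3, pad_h = 0, stride_h = 1, and
  // dilation_h = 1 gives out_h = (5 + 0 - 2 - 1) / 1 + 1 = 3.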

  ET_DCHECK_MSG(
      (out_h * out_w) == out.size(1), "dimension mismatch for output");
  ET_DCHECK_MSG(
      (kernel_h * kernel_w * in_c) == out.size(2),
      "dimension mismatch for output");

  // Check if the input is per-tensor quantized or per-channel quantized. The
  // zero point for each batch could differ for per-channel quantized input.
  bool per_tensor_quantized = in_zero_point.numel() == 1;

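// Expands to one switch case per dtype: extracts the typed data pointers and
// runs im2row_ once per batch, using the per-batch zero point unless the input
// is per-tensor quantized.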
#define typed_im2row(dtype, ctype)                                          \
  case ScalarType::dtype: {                                                 \
    const ctype* __restrict__ in_data = input.const_data_ptr<ctype>();      \
    ctype* __restrict__ out_data = out.mutable_data_ptr<ctype>();           \
    const int32_t* __restrict__ zero_point =                                \
        in_zero_point.const_data_ptr<int32_t>();                            \
    int32_t in_plane = in_c * in_h * in_w;                                  \
    int32_t out_plane = kernel_h * kernel_w * in_c * out_h * out_w;         \
    for (size_t n = 0; n < batch_size; ++n) {                               \
      im2row_<ctype>(                                                       \
          &in_data[n * in_plane],                                           \
          per_tensor_quantized ? zero_point[0] : zero_point[n],             \
          in_c,                                                             \
          in_h,                                                             \
          in_w,                                                             \
          out_h,                                                            \
          out_w,                                                            \
          kernel_h,                                                         \
          kernel_w,                                                         \
          pad_h,                                                            \
          pad_w,                                                            \
          stride_h,                                                         \
          stride_w,                                                         \
          dilation_h,                                                       \
          dilation_w,                                                       \
          &out_data[n * out_plane],                                         \
          channel_last);                                                    \
    }                                                                       \
    break;                                                                  \
  }

  ScalarType dtype = input.scalar_type();
  switch (dtype) {
    typed_im2row(Float, float);
    typed_im2row(Byte, uint8_t);
    typed_im2row(Char, int8_t);
    default:
      ET_DCHECK_MSG(
          false,
          "im2row not implemented for dtype %s",
          torch::executor::toString(dtype));
  }
#undef typed_im2row
}

} // namespace native
} // namespace reference
} // namespace impl