//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implement the Winograd Conv2D algorithm. The implementation is based on the
// paper: Fast Algorithms for Convolutional Neural Networks
// (https://arxiv.org/abs/1509.09308)
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace linalg {

namespace {

using TransformMapKeyTy = std::pair<int, int>;

// We use F(m, r) to denote the size of a minimal filtering algorithm.
// m is the output dimension and r is the filter dimension. The input
// dimension, alpha, follows from the formula alpha = m + r - 1.
//
// For example, when m = 2 and r = 3, the input size is 4. The Conv2D
// operates on 4x4 input data with a 3x3 filter and produces a 2x2 output.
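//
// For reference, the paper cited above computes each 2-D output tile as
//
//   Y = A^T [ (G g G^T) . (B^T d B) ] A
//
// where g is an r x r filter tile, d is an alpha x alpha input tile, "."
// is element-wise (Hadamard) multiplication, and G, B^T, and A^T are the
// constant filter, input, and output transform matrices of a given
// F(m, r). For the 1-D variants F(m x 1, r x 1) and F(1 x m, 1 x r),
// only one side of each transform is applied.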
constexpr TransformMapKeyTy F_2_3{2, 3};
constexpr TransformMapKeyTy F_4_3{4, 3};
constexpr TransformMapKeyTy F_2_5{2, 5};

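// Collapse the first two dimensions of a 4-D tensor into one, e.g. convert
// [H, W, N, C] to [H * W, N, C], so the result can be consumed directly by
// linalg.batch_matmul.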
Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) {
  auto type = cast<ShapedType>(data.getType());
  auto elementType = type.getElementType();
  auto shape = type.getShape();
  auto collapseType = RankedTensorType::get(
      {shape[0] * shape[1], shape[2], shape[3]}, elementType);
  SmallVector<ReassociationIndices> reassociation = {{0, 1}, {2}, {3}};
  return rewriter.create<tensor::CollapseShapeOp>(loc, collapseType, data,
                                                  reassociation);
}

// This function generates linalg.batch_matmul to multiply the input with the
// filter. linalg.batch_matmul only supports 3-dimensional data sets. We can
// treat H x W data as a 1-dimensional data array, that is, convert
// [H, W, N, C] to [H x W, N, C]. In this way, we convert the 4-dimensional
// inputs to a 3-dimensional representation that is suitable for
// linalg.batch_matmul.
//
// The batched matmul does the matrix multiplication with the reduction on
// channel.
//
// We get
//
// %collapsed_input = tensor.collapse_shape %input
// %collapsed_filter = tensor.collapse_shape %filter
// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
// %expanded_ret = tensor.expand_shape %ret
//
// After this function, we get a return value with data layout (H, W, N, F).
//
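// For example (an illustrative walk-through with assumed shapes, not
// enforced here): with F(4, 3), a single 6x6 tile, N = 2, C = 5, and
// F = 2, the transformed input (6, 6, 2, 5) and transformed filter
// (6, 6, 5, 2) collapse to (36, 2, 5) and (36, 5, 2). linalg.batch_matmul
// then produces (36, 2, 2), which is expanded back to (6, 6, 2, 2).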
Value matrixMultiply(RewriterBase &rewriter, Location loc,
                     Value transformedFilter, Value transformedInput) {
  auto collapseFilter = collapse2DData(rewriter, loc, transformedFilter);
  auto collapseInput = collapse2DData(rewriter, loc, transformedInput);

  // Batched matrix multiply.
  auto filterType = cast<ShapedType>(transformedFilter.getType());
  auto filterShape = filterType.getShape();
  auto inputType = cast<ShapedType>(transformedInput.getType());
  auto inputElemType = inputType.getElementType();
  auto inputShape = inputType.getShape();

  auto matmulType = RankedTensorType::get(
      {inputShape[0] * inputShape[1], inputShape[2], filterShape[3]},
      inputElemType);
  Value init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
                                                inputElemType);

  auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
      loc, matmulType, ValueRange({collapseInput, collapseFilter}),
      ValueRange{init});

  // Expand the matmul result back to 4-D.
  SmallVector<ReassociationIndices> reassociation = {{0, 1}, {2}, {3}};
  auto expandType = RankedTensorType::get(
      {inputShape[0], inputShape[1], inputShape[2], filterShape[3]},
      inputElemType);
  auto expandOutput = rewriter.create<tensor::ExpandShapeOp>(
      loc, expandType, matmulOp.getResult(0), reassociation);
  return expandOutput;
}

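// Decompose a linalg.conv_2d_nhwc_fhwc into a Winograd filter transform, a
// Winograd input transform, a batched matrix multiplication, and a Winograd
// output transform. Returns the output transform operation on success, or
// failure if the convolution configuration is not supported.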
FailureOr<Operation *> winogradConv2DHelper(RewriterBase &rewriter,
                                            linalg::Conv2DNhwcFhwcOp convOp,
                                            int64_t m, int64_t r) {
  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  auto outputType = cast<ShapedType>(output.getType());
  int64_t outputH = outputType.getShape()[1];
  int64_t outputW = outputType.getShape()[2];
  auto filterType = cast<ShapedType>(filter.getType());
  auto filterShape = filterType.getShape(); // F, H, W, C
  int64_t filterF = filterShape[0];
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];
  int64_t filterC = filterShape[3];
  auto inputType = cast<ShapedType>(input.getType());
  auto inputShape = inputType.getShape(); // N, H, W, C
  int64_t inputN = inputShape[0];
  int64_t inputC = inputShape[3];

  // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r).
  if ((outputH != outputW) && (outputH != 1 && outputW != 1))
    return failure();
  if ((filterH != filterW) && (filterH != 1 && filterW != 1))
    return failure();

  if ((outputH == 1 && filterH != 1) || (outputH != 1 && filterH == 1))
    return failure();
  if ((outputW == 1 && filterW != 1) || (outputW != 1 && filterW == 1))
    return failure();

  // The set of (m, r) configurations for which constant transform matrices
  // are available.
  static const llvm::SmallVector<TransformMapKeyTy, 3> validConfigs = {
      F_2_3, F_4_3, F_2_5};

  TransformMapKeyTy key = {m, r};
  auto it = std::find(validConfigs.begin(), validConfigs.end(), key);
  // If we cannot find the constant transformation matrices, it means we do
  // not support this configuration yet.
  if (it == validConfigs.end())
    return failure();
  // All criteria are satisfied. We can do Winograd Conv2D.
  Location loc = convOp.getLoc();

  // For F(m x 1, r x 1), we only need to do the left side transform.
  bool leftTransform = outputH != 1;
  // For F(1 x m, 1 x r), we only need to do the right side transform.
  bool rightTransform = outputW != 1;

  // Create the operator for the filter transform.
  Type elementType = filterType.getElementType();
  int64_t alphaH = leftTransform ? m + r - 1 : 1;
  int64_t alphaW = rightTransform ? m + r - 1 : 1;
  int64_t retHeight = leftTransform ? (outputH / m) * alphaH : 1;
  int64_t retWidth = rightTransform ? (outputW / m) * alphaW : 1;
  auto retType = RankedTensorType::get({retHeight, retWidth, filterC, filterF},
                                       elementType);
  Value retValue =
      rewriter.create<tensor::EmptyOp>(loc, retType.getShape(), elementType);
  auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
      loc, retType, filter, retValue, outputH, outputW, m, r);

  // Create the operator for the input transform.
  retType =
      RankedTensorType::get({retHeight, retWidth, inputN, inputC}, elementType);
  retValue =
      rewriter.create<tensor::EmptyOp>(loc, retType.getShape(), elementType);
  auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>(
      loc, retType, input, retValue, outputH, outputW, m, r);

  Value matmulRet =
      matrixMultiply(rewriter, loc, transformedFilter, transformedInput);

  // Create the operator for the output transform.
  auto transformedOutput = rewriter.create<linalg::WinogradOutputTransformOp>(
      loc, outputType, matmulRet, output, m, r);

  rewriter.replaceOp(convOp, transformedOutput);

  return transformedOutput.getOperation();
}

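// Rewrite pattern that matches a linalg.conv_2d_nhwc_fhwc whose filter and
// output sizes are compatible with the configured F(m, r) and rewrites it
// with the Winograd algorithm.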
class WinogradConv2DNhwcFhwc final
    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, int64_t m, int64_t r)
      : OpRewritePattern(context), m(m), r(r) {}

  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
                                PatternRewriter &rewriter) const override {
    Value filter = convOp.getInputs()[1];
    auto filterType = cast<ShapedType>(filter.getType());
    auto filterShape = filterType.getShape(); // F, H, W, C
    int64_t filterH = filterShape[1];
    int64_t filterW = filterShape[2];
    Value output = convOp.getOutputs()[0];
    auto outputType = cast<ShapedType>(output.getType());
    auto outputShape = outputType.getShape(); // N, H, W, C
    int64_t outputH = outputShape[1];
    int64_t outputW = outputShape[2];

    // Each filter dimension must be either r or 1.
    if ((filterH != r && filterH != 1) || (filterW != r && filterW != 1))
      return failure();

    // Each output dimension must be either at least m or exactly 1.
    if ((outputH < m && outputH != 1) || (outputW < m && outputW != 1))
      return failure();

    if (failed(winogradConv2DHelper(rewriter, convOp, m, r)))
      return failure();

    return success();
  }

private:
  int64_t m;
  int64_t r;
};
} // end anonymous namespace

//===----------------------------------------------------------------------===//
void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
                                    int64_t r) {
  MLIRContext *context = patterns.getContext();
  patterns.insert<WinogradConv2DNhwcFhwc>(context, m, r);
}
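
// Example driver (a sketch, not part of this file): assuming a pass that
// holds a ModuleOp `module`, the pattern above can be applied with the
// greedy rewrite driver whose header is already included:
//
//   RewritePatternSet patterns(module.getContext());
//   linalg::populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3);
//   if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns))))
//     signalPassFailure();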

} // end namespace linalg
} // end namespace mlir