//===- OneDNNGraphToLinalg.cpp - OneDNN Graph To Linalg Lowering -*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <numeric>
#include <vector>

#include "gc-dialects/OneDNNGraph/OneDNNGraphDialect.h"
#include "gc-dialects/OneDNNGraph/OneDNNGraphOps.h"
#include "gc-dialects/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir::onednn_graph;

namespace mlir {
namespace gc {
#define GEN_PASS_DEF_CONVERTONEDNNGRAPHTOLINALG
#include "gc-dialects/Passes.h.inc"

namespace {
//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

// Callback interface for constructing the scalar body of a lowered
// elementwise op.
struct CreateElementwiseOp {
  virtual ~CreateElementwiseOp() = default;
  virtual Value create(OpBuilder &b, Location loc, ValueRange args) const = 0;
};

// Generate an elementwise op as a linalg::GenericOp with identity indexing
// maps and all-parallel iterators.
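//
// For a rank-2 f32 add, the emitted IR looks roughly like the sketch below
// (names, shapes, and the #identity map alias are invented for illustration):
//
//   %empty = tensor.empty() : tensor<4x8xf32>
//   %res = linalg.generic
//       {indexing_maps = [#identity, #identity, #identity],
//        iterator_types = ["parallel", "parallel"]}
//       ins(%lhs, %rhs : tensor<4x8xf32>, tensor<4x8xf32>)
//       outs(%empty : tensor<4x8xf32>) {
//     ^bb0(%a: f32, %b: f32, %out: f32):
//       %sum = arith.addf %a, %b : f32
//       linalg.yield %sum : f32
//   } -> tensor<4x8xf32>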
Value createElemwiseOp(Location loc, PatternRewriter &rewriter, TensorType ty,
                       llvm::ArrayRef<Value> inputs,
                       const CreateElementwiseOp &createOp) {
  // create indexing maps for the elementwise op; all are identity maps
  llvm::SmallVector<AffineMap> indexingMaps( //
      inputs.size() + 1,                     //
      rewriter.getMultiDimIdentityMap(ty.getRank()));
  // all iterators are "parallel"; there is no "reduction" axis
  llvm::SmallVector<utils::IteratorType> iteratorTypes( //
      ty.getRank(),                                     //
      utils::IteratorType::parallel);

  // create a new output tensor
  Value outTensor =
      rewriter.create<tensor::EmptyOp>(loc, ty.getShape(), ty.getElementType());

  auto elemwiseOp = rewriter.create<linalg::GenericOp>(
      /*location=*/loc,
      /*resultTensorTypes=*/outTensor.getType(),
      /*inputs=*/inputs,
      /*outputs=*/outTensor,
      /*indexingMaps=*/indexingMaps,
      /*iteratorTypes=*/iteratorTypes,
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value result = createOp.create(b, loc, args);
        b.create<linalg::YieldOp>(loc, result);
      });

  return elemwiseOp.getResult(0);
}

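// Broadcast an operand up to the target type's shape using linalg.broadcast,
// aligning the operand's trailing dimensions with the result's. For a 1-D
// bias broadcast into a 2-D result, the emitted IR is roughly as follows
// (an illustrative sketch, names invented):
//
//   %empty = tensor.empty() : tensor<4x8xf32>
//   %bc = linalg.broadcast ins(%bias : tensor<8xf32>)
//             outs(%empty : tensor<4x8xf32>) dimensions = [0]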
Value createBroadcastOperand(Location loc, PatternRewriter &rewriter,
                             TensorType ty, Value op) {
  auto opTy = dyn_cast<TensorType>(op.getType());
  llvm::ArrayRef<int64_t> bcastShape = ty.getShape();
  llvm::ArrayRef<int64_t> opShape = opTy.getShape();
  int64_t diff = int64_t(bcastShape.size()) - int64_t(opShape.size());
  //
  if (bcastShape.equals(opShape)) {
    return op;
  }
  // get broadcast dimensions: leading dims the operand lacks, plus any dims
  // where the sizes disagree
  llvm::SmallVector<int64_t> bcastDims;
  for (int64_t i = 0; i < int64_t(bcastShape.size()); i++) {
    int64_t idxOp = i - diff;
    if (idxOp < 0 || bcastShape[i] != opShape[idxOp]) {
      bcastDims.push_back(i);
    }
  }
  // create a new output tensor
  Value initTensor =
      rewriter.create<tensor::EmptyOp>(loc, bcastShape, ty.getElementType());
  return rewriter
      .create<linalg::BroadcastOp>(
          /*location=*/loc,
          /*inputs=*/op,
          /*inits=*/initTensor,
          /*dimensions=*/bcastDims)
      .getResults()
      .front();
}

//===----------------------------------------------------------------------===//
// UnaryOp lowering
//===----------------------------------------------------------------------===//

template <typename UnaryOp, typename CreateLoweredOp>
struct UnaryElemwiseLowering : public OpRewritePattern<UnaryOp> {
  using OpRewritePattern<UnaryOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(UnaryOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op->getLoc();
    auto resultTy = dyn_cast<TensorType>(op->getResultTypes().front());
    //
    auto unaryOp = createElemwiseOp(loc, rewriter, resultTy,
                                    {op->getOperand(0)}, CreateLoweredOp());
    rewriter.replaceOp(op, unaryOp);
    return success();
  }
};

template <typename LoweredOp>
struct CreateLoweredUnaryOp : public CreateElementwiseOp {
  Value create(OpBuilder &b, Location loc, ValueRange args) const final {
    return b.create<LoweredOp>(loc, args[0]);
  }
};

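// relu(x) = max(x, 0): materialize a zero constant of the (assumed float)
// element type and take the elementwise maximum.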
struct CreateLoweredReLUOp : public CreateElementwiseOp {
  Value create(OpBuilder &b, Location loc, ValueRange args) const final {
    Value input = args[0];
    Value zeros =
        b.create<arith::ConstantOp>(loc, FloatAttr::get(input.getType(), 0.f));
    return b.create<arith::MaximumFOp>(loc, input, zeros);
  }
};

//===----------------------------------------------------------------------===//
// BinaryOp lowering
//===----------------------------------------------------------------------===//

template <typename BinaryOp, typename CreateLoweredOp>
struct BinaryElemwiseLowering : public OpRewritePattern<BinaryOp> {
  using OpRewritePattern<BinaryOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(BinaryOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op->getLoc();
    auto resultTy = dyn_cast<TensorType>(op->getResultTypes().front());
    //
    auto lhsOp =
        createBroadcastOperand(loc, rewriter, resultTy, op->getOperand(0));
    auto rhsOp =
        createBroadcastOperand(loc, rewriter, resultTy, op->getOperand(1));
    //
    auto binaryOp = createElemwiseOp(loc, rewriter, resultTy, {lhsOp, rhsOp},
                                     CreateLoweredOp());
    rewriter.replaceOp(op, binaryOp);
    return success();
  }
};

template <typename LoweredOp>
struct CreateLoweredBinaryOp : public CreateElementwiseOp {
  Value create(OpBuilder &b, Location loc, ValueRange args) const final {
    return b.create<LoweredOp>(loc, args[0], args[1]);
  }
};

//===----------------------------------------------------------------------===//
// Op lowering
//===----------------------------------------------------------------------===//

using ReLUOpLowering = UnaryElemwiseLowering< //
    onednn_graph::ReLUOp, CreateLoweredReLUOp>;
// using ExpOpLowering = UnaryElemwiseLowering< //
//     onednn_graph::ExpOp, CreateLoweredUnaryOp<math::ExpOp>>;

using AddOpLowering = BinaryElemwiseLowering< //
    onednn_graph::AddOp, CreateLoweredBinaryOp<arith::AddFOp>>;
// using SubOpLowering = BinaryElemwiseLowering< //
//     onednn_graph::SubOp, CreateLoweredBinaryOp<arith::SubFOp>>;
// using MulOpLowering = BinaryElemwiseLowering< //
//     onednn_graph::MulOp, CreateLoweredBinaryOp<arith::MulFOp>>;
// using DivOpLowering = BinaryElemwiseLowering< //
//     onednn_graph::DivOp, CreateLoweredBinaryOp<arith::DivFOp>>;

//===----------------------------------------------------------------------===//
// MatMulOp lowering
//===----------------------------------------------------------------------===//

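// The transpose_a/transpose_b flag combinations map onto linalg named ops as
// follows (the last case relies on the identity T(A) * T(B) = T(B * A)):
//   (false, false): A * B       -> linalg.matmul
//   (true,  false): T(A) * B    -> linalg.matmul_transpose_a
//   (false, true ): A * T(B)    -> linalg.matmul_transpose_b
//   (true,  true ): T(A) * T(B) -> linalg.matmul(B, A) + linalg.transpose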
struct MatMulOpLowering : public OpRewritePattern<MatMulOp> {
  using OpRewritePattern<MatMulOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(MatMulOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op->getLoc();
    auto resultTy = dyn_cast<TensorType>(op->getResultTypes().front());
    //
    Value newTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultTy.getElementType());
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(resultTy.getElementType()));
    Value outTensor =
        rewriter.create<linalg::FillOp>(loc, zero, newTensor).getResult(0);

    bool transposeA = op.getTransposeA();
    bool transposeB = op.getTransposeB();
    Operation *newOp;
    if (!transposeA && !transposeB) {
      // A * B
      newOp = rewriter.create<linalg::MatmulOp>(
          /*location=*/loc,
          /*resultTensorTypes=*/resultTy,
          /*inputs=*/ValueRange{op.getInputA(), op.getInputB()},
          /*outputs=*/outTensor);
    } else if (transposeA && !transposeB) {
      // T(A) * B
      newOp = rewriter.create<linalg::MatmulTransposeAOp>(
          /*location=*/loc,
          /*resultTensorTypes=*/resultTy,
          /*inputs=*/ValueRange{op.getInputA(), op.getInputB()},
          /*outputs=*/outTensor);
    } else if (!transposeA && transposeB) {
      // A * T(B)
      newOp = rewriter.create<linalg::MatmulTransposeBOp>(
          /*location=*/loc,
          /*resultTensorTypes=*/resultTy,
          /*inputs=*/ValueRange{op.getInputA(), op.getInputB()},
          /*outputs=*/outTensor);
    } else {
      // T(A) * T(B) == T(B * A); compute B * A into an intermediate whose
      // last two dims are swapped relative to the result, then transpose it
      // into the final shape.
      int64_t rank = resultTy.getRank();
      SmallVector<int64_t> permutation(rank);
      std::iota(std::begin(permutation), std::end(permutation), 0);
      std::swap(permutation[rank - 2], permutation[rank - 1]);
      SmallVector<int64_t> transShape(resultTy.getShape().begin(),
                                      resultTy.getShape().end());
      std::swap(transShape[rank - 2], transShape[rank - 1]);
      Value transTensor = rewriter.create<tensor::EmptyOp>(
          loc, transShape, resultTy.getElementType());
      Value transOut =
          rewriter.create<linalg::FillOp>(loc, zero, transTensor).getResult(0);
      auto matmulOp = rewriter.create<linalg::MatmulOp>(
          /*location=*/loc,
          /*resultTensorTypes=*/transOut.getType(),
          /*inputs=*/ValueRange{op.getInputB(), op.getInputA()},
          /*outputs=*/transOut);
      newOp = rewriter.create<linalg::TransposeOp>(
          /*location=*/loc,
          /*inputs=*/matmulOp.getResult(0),
          /*outputs=*/outTensor,
          /*permutation=*/permutation);
    }

    if (op.getBias()) {
      // linalg.add takes both addends as inputs and a separate init tensor
      auto bias = createBroadcastOperand(loc, rewriter, resultTy, op.getBias());
      newOp = rewriter.create<linalg::AddOp>(
          /*location=*/loc,
          /*resultTensorTypes=*/outTensor.getType(),
          /*inputs=*/ValueRange{newOp->getResult(0), bias},
          /*outputs=*/outTensor);
    }

    rewriter.replaceOp(op, newOp);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// Pass definition
//===----------------------------------------------------------------------===//

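// Greedily applies the lowering patterns above until no more onednn_graph
// ops matched by them remain.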
struct ConvertOneDNNGraphToLinalg
    : public impl::ConvertOneDNNGraphToLinalgBase<ConvertOneDNNGraphToLinalg> {

  void runOnOperation() final {
    //
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    patterns.add<AddOpLowering>(ctx);
    patterns.add<ReLUOpLowering>(ctx);
    patterns.add<MatMulOpLowering>(ctx);
    //
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns)))) {
      signalPassFailure();
    }
  }
};

} // namespace
} // namespace gc
} // namespace mlir