|
//===- OneDNNGraphToLinalg.cpp - OneDNN Graph To Linalg Lowering -*- C++ -*-=//
| 3 | +// |
| 4 | +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. |
| 5 | +// See https://llvm.org/LICENSE.txt for license information. |
| 6 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 7 | +// |
| 8 | +//===----------------------------------------------------------------------===// |
| 9 | + |
#include <cassert>
#include <numeric>
#include <vector>

#include "gc-dialects/OneDNNGraph/OneDNNGraphDialect.h"
#include "gc-dialects/OneDNNGraph/OneDNNGraphOps.h"
#include "gc-dialects/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
| 25 | + |
| 26 | +using namespace mlir::onednn_graph; |
| 27 | + |
| 28 | +namespace mlir { |
| 29 | +namespace gc { |
| 30 | +#define GEN_PASS_DEF_CONVERTONEDNNGRAPHTOLINALG |
| 31 | +#include "gc-dialects/Passes.h.inc" |
| 32 | + |
| 33 | +namespace { |
| 34 | +//===----------------------------------------------------------------------===// |
| 35 | +// Util funcs |
| 36 | +//===----------------------------------------------------------------------===// |
| 37 | + |
| 38 | +Value createBroadcastOperand(Location loc, PatternRewriter &rewriter, |
| 39 | + TensorType ty, Value op) { |
| 40 | + auto opTy = dyn_cast<TensorType>(op.getType()); |
| 41 | + llvm::ArrayRef<int64_t> bcastShape = ty.getShape(); |
| 42 | + llvm::ArrayRef<int64_t> opShape = opTy.getShape(); |
| 43 | + int64_t diff = bcastShape.size() - opShape.size(); |
| 44 | + |
| 45 | + if (bcastShape.equals(opShape)) { |
| 46 | + return op; |
| 47 | + } else { |
| 48 | + // get broadcast dimensions |
| 49 | + llvm::SmallVector<int64_t> bcastDims; |
| 50 | + for (int64_t i = 0; i < (int64_t)bcastShape.size(); i++) { |
| 51 | + int64_t idxOp = i - diff; |
| 52 | + if (idxOp < 0) { |
| 53 | + bcastDims.push_back(i); |
| 54 | + } else if (bcastShape[i] != opShape[idxOp]) { |
| 55 | + bcastDims.push_back(i); |
| 56 | + } |
| 57 | + } |
| 58 | + // create a new output tensor |
| 59 | + Value initTensor = |
| 60 | + rewriter.create<tensor::EmptyOp>(loc, bcastShape, ty.getElementType()); |
| 61 | + return rewriter |
| 62 | + .create<linalg::BroadcastOp>( |
| 63 | + /*location=*/loc, |
| 64 | + /*inputs=*/op, |
| 65 | + /*inits=*/initTensor, |
| 66 | + /*dimensions=*/bcastDims) |
| 67 | + .getResults() |
| 68 | + .front(); |
| 69 | + } |
| 70 | +} |
| 71 | + |
| 72 | +typedef Value (*OperandGet)(Operation *, PatternRewriter &, TensorType); |
| 73 | + |
| 74 | +template <unsigned I> |
| 75 | +Value OriginalOperand(Operation *op, PatternRewriter &b, TensorType ty) { |
| 76 | + return createBroadcastOperand(op->getLoc(), b, ty, op->getOperand(I)); |
| 77 | +} |
| 78 | + |
| 79 | +static Value ConstZeroOperand(Operation *op, PatternRewriter &b, |
| 80 | + TensorType ty) { |
| 81 | + auto loc = op->getLoc(); |
| 82 | + Value zero = |
| 83 | + b.create<arith::ConstantOp>(loc, b.getZeroAttr(ty.getElementType())); |
| 84 | + Value newTensor = |
| 85 | + b.create<tensor::EmptyOp>(loc, ty.getShape(), ty.getElementType()); |
| 86 | + return b.create<linalg::FillOp>(loc, zero, newTensor).getResult(0); |
| 87 | +} |
| 88 | + |
| 89 | +//===----------------------------------------------------------------------===// |
| 90 | +// Elemwise lowering |
| 91 | +//===----------------------------------------------------------------------===// |
| 92 | + |
| 93 | +// Generate elementwise op using linalg named ops |
| 94 | +template <typename LoweredOp> |
| 95 | +Value createElemwiseOp(Location loc, PatternRewriter &rewriter, TensorType ty, |
| 96 | + llvm::ArrayRef<Value> inputs) { |
| 97 | + // create a new output tensor |
| 98 | + Value outTensor = |
| 99 | + rewriter.create<tensor::EmptyOp>(loc, ty.getShape(), ty.getElementType()); |
| 100 | + |
| 101 | + auto elemwiseOp = rewriter.create<LoweredOp>( |
| 102 | + /*location=*/loc, |
| 103 | + /*resultTensorTypes=*/outTensor.getType(), |
| 104 | + /*inputs=*/inputs, |
| 105 | + /*outputs=*/outTensor); |
| 106 | + |
| 107 | + return elemwiseOp.getResult(0); |
| 108 | +} |
| 109 | + |
| 110 | +template <typename UnaryOp, typename LoweredOp> |
| 111 | +struct UnaryElemwiseLowering : public OpRewritePattern<UnaryOp> { |
| 112 | + using OpRewritePattern<UnaryOp>::OpRewritePattern; |
| 113 | + LogicalResult matchAndRewrite(UnaryOp op, |
| 114 | + PatternRewriter &rewriter) const final { |
| 115 | + auto loc = op->getLoc(); |
| 116 | + auto resultTy = dyn_cast<TensorType>(op->getResultTypes().front()); |
| 117 | + auto unaryOp = createElemwiseOp<LoweredOp>(loc, rewriter, resultTy, // |
| 118 | + {op->getOperand(0)}); |
| 119 | + rewriter.replaceOp(op, unaryOp); |
| 120 | + return success(); |
| 121 | + } |
| 122 | +}; |
| 123 | + |
| 124 | +template <typename BinaryOp, typename LoweredOp, OperandGet GetOperandLHS, |
| 125 | + OperandGet GetOperandRHS> |
| 126 | +struct BinaryElemwiseLowering : public OpRewritePattern<BinaryOp> { |
| 127 | + using OpRewritePattern<BinaryOp>::OpRewritePattern; |
| 128 | + LogicalResult matchAndRewrite(BinaryOp op, |
| 129 | + PatternRewriter &rewriter) const final { |
| 130 | + auto loc = op->getLoc(); |
| 131 | + auto resultTy = dyn_cast<TensorType>(op->getResultTypes().front()); |
| 132 | + auto lhsOp = GetOperandLHS(op, rewriter, resultTy); |
| 133 | + auto rhsOp = GetOperandRHS(op, rewriter, resultTy); |
| 134 | + auto binaryOp = createElemwiseOp<LoweredOp>(loc, rewriter, resultTy, // |
| 135 | + {lhsOp, rhsOp}); |
| 136 | + rewriter.replaceOp(op, binaryOp); |
| 137 | + return success(); |
| 138 | + } |
| 139 | +}; |
| 140 | + |
| 141 | +//===----------------------------------------------------------------------===// |
| 142 | +// Op lowering |
| 143 | +//===----------------------------------------------------------------------===// |
| 144 | + |
// ReLU(x) lowered as elementwise max(x, 0): the RHS is a zero-filled
// tensor from ConstZeroOperand.
using ReLUOpLowering = BinaryElemwiseLowering< //
    onednn_graph::ReLUOp, linalg::MaxOp, OriginalOperand<0>, ConstZeroOperand>;

// Elementwise add; both operands broadcast to the result shape if needed.
using AddOpLowering = BinaryElemwiseLowering< //
    onednn_graph::AddOp, linalg::AddOp, OriginalOperand<0>, OriginalOperand<1>>;
| 150 | + |
| 151 | +//===----------------------------------------------------------------------===// |
| 152 | +// MatMulOp lowering |
| 153 | +//===----------------------------------------------------------------------===// |
| 154 | + |
| 155 | +struct MatMulOpLowering : public OpRewritePattern<MatMulOp> { |
| 156 | + using OpRewritePattern<MatMulOp>::OpRewritePattern; |
| 157 | + LogicalResult matchAndRewrite(MatMulOp op, |
| 158 | + PatternRewriter &rewriter) const final { |
| 159 | + auto loc = op->getLoc(); |
| 160 | + auto resultTy = dyn_cast<TensorType>(op->getResultTypes().front()); |
| 161 | + // |
| 162 | + Value newTensor = rewriter.create<tensor::EmptyOp>( |
| 163 | + loc, resultTy.getShape(), resultTy.getElementType()); |
| 164 | + Value zero = rewriter.create<arith::ConstantOp>( |
| 165 | + loc, rewriter.getZeroAttr(resultTy.getElementType())); |
| 166 | + Value outTensor = |
| 167 | + rewriter.create<linalg::FillOp>(loc, zero, newTensor).getResult(0); |
| 168 | + |
| 169 | + bool transposeA = op.getTransposeA(); |
| 170 | + bool transposeB = op.getTransposeB(); |
| 171 | + Operation *newOp; |
| 172 | + if (!transposeA && !transposeB) { |
| 173 | + // (A * B) |
| 174 | + newOp = rewriter.create<linalg::MatmulOp>( |
| 175 | + /*location=*/loc, |
| 176 | + /*resultTensorTypes=*/resultTy, |
| 177 | + /*inputs=*/ValueRange{op.getInputA(), op.getInputB()}, |
| 178 | + /*outputs=*/outTensor); |
| 179 | + } else if (transposeA && !transposeB) { |
| 180 | + // T(A) * B |
| 181 | + newOp = rewriter.create<linalg::MatmulTransposeAOp>( |
| 182 | + /*location=*/loc, |
| 183 | + /*resultTensorTypes=*/resultTy, |
| 184 | + /*inputs=*/ValueRange{op.getInputA(), op.getInputB()}, |
| 185 | + /*outputs=*/outTensor); |
| 186 | + } else if (!transposeA && transposeB) { |
| 187 | + // A * T(B) |
| 188 | + newOp = rewriter.create<linalg::MatmulTransposeBOp>( |
| 189 | + /*location=*/loc, |
| 190 | + /*resultTensorTypes=*/resultTy, |
| 191 | + /*inputs=*/ValueRange{op.getInputA(), op.getInputB()}, |
| 192 | + /*outputs=*/outTensor); |
| 193 | + } else { |
| 194 | + // T(B * A) |
| 195 | + int64_t rank = resultTy.getRank(); |
| 196 | + SmallVector<int64_t> permutation(rank); |
| 197 | + std::iota(std::begin(permutation), std::end(permutation), 0); |
| 198 | + permutation[rank - 2] = rank - 1; |
| 199 | + permutation[rank - 1] = rank - 2; |
| 200 | + auto matmulOp = rewriter.create<linalg::MatmulOp>( |
| 201 | + /*location=*/loc, |
| 202 | + /*resultTensorTypes=*/resultTy, |
| 203 | + /*inputs=*/ValueRange{op.getInputB(), op.getInputA()}, |
| 204 | + /*outputs=*/outTensor); |
| 205 | + newOp = rewriter.create<linalg::TransposeOp>( |
| 206 | + /*location=*/loc, |
| 207 | + /*inputs=*/matmulOp.getResult(0), |
| 208 | + /*outputs=*/outTensor, |
| 209 | + /*permutation=*/permutation); |
| 210 | + } |
| 211 | + |
| 212 | + if (op.getBias()) { |
| 213 | + auto bias = createBroadcastOperand(loc, rewriter, resultTy, op.getBias()); |
| 214 | + newOp = rewriter.create<linalg::AddOp>( |
| 215 | + /*location=*/loc, |
| 216 | + /*resultTensorTypes=*/outTensor.getType(), |
| 217 | + /*inputs=*/newOp->getResult(0), |
| 218 | + /*outputs=*/bias); |
| 219 | + } |
| 220 | + |
| 221 | + rewriter.replaceOp(op, newOp); |
| 222 | + return success(); |
| 223 | + } |
| 224 | +}; |
| 225 | + |
| 226 | +//===----------------------------------------------------------------------===// |
| 227 | +// Pass define |
| 228 | +//===----------------------------------------------------------------------===// |
| 229 | + |
| 230 | +struct ConvertOneDNNGraphToLinalg |
| 231 | + : public impl::ConvertOneDNNGraphToLinalgBase<ConvertOneDNNGraphToLinalg> { |
| 232 | + |
| 233 | + void runOnOperation() final { |
| 234 | + // |
| 235 | + auto *ctx = &getContext(); |
| 236 | + RewritePatternSet patterns(ctx); |
| 237 | + patterns.add<AddOpLowering>(ctx); |
| 238 | + patterns.add<ReLUOpLowering>(ctx); |
| 239 | + patterns.add<MatMulOpLowering>(ctx); |
| 240 | + // |
| 241 | + if (failed(applyPatternsAndFoldGreedily(getOperation(), |
| 242 | + std::move(patterns)))) { |
| 243 | + signalPassFailure(); |
| 244 | + } |
| 245 | + } |
| 246 | +}; |
| 247 | + |
| 248 | +} // namespace |
| 249 | +} // namespace gc |
| 250 | +} // namespace mlir |
0 commit comments