[Transform] Add basic onednn_graph dialect lowering #61

Merged · 5 commits · May 14, 2024
Changes from 1 commit
14 changes: 14 additions & 0 deletions include/gc/Transforms/Passes.td
@@ -17,4 +17,18 @@ def TileLinalgNamed : Pass<"tile-named-linalg", "func::FuncOp"> {
["linalg::LinalgDialect", "scf::SCFDialect", "tensor::TensorDialect"];
}

def ConvertOneDNNGraphToLinalg : Pass<"convert-onednn-graph-to-linalg"> {
let summary = "Lower the operations from the oneDNN Graph dialect into Linalg";
let description = [{
Lowers the `onednn_graph` ops to `linalg` ops.
}];
let dependentDialects = [
"func::FuncDialect",
"math::MathDialect",
"arith::ArithDialect",
"tensor::TensorDialect",
"linalg::LinalgDialect"
];
}

#endif // GC_DIALECT_GC_PASSES
2 changes: 2 additions & 0 deletions lib/gc/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_library(GCPasses
OneDNNGraphToLinalg.cpp
TileNamed.cpp

ADDITIONAL_HEADER_DIRS
@@ -9,6 +10,7 @@ add_mlir_library(GCPasses

LINK_LIBS PUBLIC
${mlir_dialect_libs}
MLIROneDNNGraph
MLIRIR
MLIRSupport
MLIRBufferizationToMemRef
280 changes: 280 additions & 0 deletions lib/gc/Transforms/OneDNNGraphToLinalg.cpp
@@ -0,0 +1,280 @@
//===- OneDNNGraphToLinalg.cpp - OneDNN Graph To Linalg Lowering -*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <numeric>
#include <vector>

#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h"
#include "gc/Dialect/OneDNNGraph/OneDNNGraphOps.h"
#include "gc/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir::onednn_graph;

namespace mlir {
namespace gc {
#define GEN_PASS_DEF_CONVERTONEDNNGRAPHTOLINALG
#include "gc/Transforms/Passes.h.inc"

namespace {
//===----------------------------------------------------------------------===//
// Util funcs
//===----------------------------------------------------------------------===//

Value createBroadcastOperand(Location loc, PatternRewriter &rewriter,
TensorType ty, Value op) {
auto opTy = dyn_cast<TensorType>(op.getType());
llvm::ArrayRef<int64_t> bcastShape = ty.getShape();
llvm::ArrayRef<int64_t> opShape = opTy.getShape();
int64_t diff = bcastShape.size() - opShape.size();

if (bcastShape.equals(opShape)) {
return op;
} else {
// get broadcast dimensions
llvm::SmallVector<int64_t> bcastDims;
for (int64_t i = 0; i < (int64_t)bcastShape.size(); i++) {
int64_t idxOp = i - diff;
if (idxOp < 0) {
bcastDims.push_back(i);
} else if (bcastShape[i] != opShape[idxOp]) {
bcastDims.push_back(i);
}
}
// create a new output tensor
Value initTensor =
rewriter.create<tensor::EmptyOp>(loc, bcastShape, ty.getElementType());
return rewriter
.create<linalg::BroadcastOp>(
/*location=*/loc,
/*inputs=*/op,
/*inits=*/initTensor,
/*dimensions=*/bcastDims)
.getResults()
.front();
}
}

typedef Value (*GetOperandFn)(Operation *, PatternRewriter &, TensorType);
Contributor: The definition is a bit confusing, since there's already mlir::Value (even though the result is still the same). If there's a better option, go for it; otherwise it's fine.

Contributor (author): This is a typedef for a function pointer that returns mlir::Value.
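A possible cleanup (a sketch only, not part of this PR) is the alias-declaration spelling, which makes the mlir::Value return type easier to spot:

// Equivalent alias-style spelling of the typedef above (sketch only).
using GetOperandFn = Value (*)(Operation *, PatternRewriter &, TensorType);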


struct OriginalOperand {
template <unsigned I>
static Value getIdx(Operation *op, PatternRewriter &b, TensorType ty) {
if (I >= op->getNumOperands()) {
op->emitError("Index exceeds operand num.\n");
return nullptr;
}
return createBroadcastOperand(op->getLoc(), b, ty, op->getOperand(I));
}
};

struct ConstantOperand {
template <int64_t I>
static Value getConst(Operation *op, PatternRewriter &b, TensorType ty) {
const auto loc = op->getLoc();
if (llvm::isa<IntegerType>(ty.getElementType())) {
return b.create<arith::ConstantOp>( //
loc, DenseElementsAttr::get(ty, int64_t(I)));
} else if (llvm::isa<FloatType>(ty.getElementType())) {
return b.create<arith::ConstantOp>( //
loc, DenseElementsAttr::get(ty, float(I)));
} else {
op->emitError("Not a supported element type for constant.\n");
return nullptr;
}
}
};
Contributor: These methods are stateless. What are they wrapped in a struct for? Is it just logical grouping, in anticipation of more methods?

Contributor (author): Just for logical grouping; more will be added, e.g. getting a constant operand from attributes, as in onednn_graph.pow.
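For illustration, a hypothetical member in that spirit might build the constant from a float attribute. The helper and the attribute name "beta" below are assumptions, not part of this PR or the onednn_graph API:

// Hypothetical sketch: build a splat constant from a float attribute.
// The attribute name "beta" is assumed for illustration only.
struct AttributeOperand {
  static Value getFloatAttrConst(Operation *op, PatternRewriter &b,
                                 TensorType ty) {
    auto attr = op->getAttrOfType<FloatAttr>("beta");
    if (!attr) {
      op->emitError("Missing expected float attribute.\n");
      return nullptr;
    }
    return b.create<arith::ConstantOp>(
        op->getLoc(),
        DenseElementsAttr::get(ty, float(attr.getValueAsDouble())));
  }
};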


//===----------------------------------------------------------------------===//
// Elemwise lowering
//===----------------------------------------------------------------------===//

// Generate elementwise op using linalg named ops
template <typename LoweredOp>
Value createElemwiseOp(Location loc, PatternRewriter &rewriter, TensorType ty,
llvm::ArrayRef<Value> inputs) {
// create a new output tensor
Value outTensor =
rewriter.create<tensor::EmptyOp>(loc, ty.getShape(), ty.getElementType());

auto elemwiseOp = rewriter.create<LoweredOp>(
/*location=*/loc,
/*resultTensorTypes=*/outTensor.getType(),
/*inputs=*/inputs,
/*outputs=*/outTensor);

return elemwiseOp.getResult(0);
}

template <typename UnaryOp, typename LoweredOp, GetOperandFn GetOperand>
struct UnaryElemwiseLowering : public OpRewritePattern<UnaryOp> {
using OpRewritePattern<UnaryOp>::OpRewritePattern;
LogicalResult matchAndRewrite(UnaryOp op,
PatternRewriter &rewriter) const final {
auto loc = op->getLoc();
auto resultTy = dyn_cast<TensorType>(op->getResultTypes().front());
auto inOp = GetOperand(op, rewriter, resultTy);
if (!inOp) {
return rewriter.notifyMatchFailure(op, "Failed to get operand.");
}
auto unaryOp = createElemwiseOp<LoweredOp>(loc, rewriter, resultTy, {inOp});
rewriter.replaceOp(op, unaryOp);
return success();
}
};

template <typename BinaryOp, typename LoweredOp, GetOperandFn GetOperandLHS,
GetOperandFn GetOperandRHS>
struct BinaryElemwiseLowering : public OpRewritePattern<BinaryOp> {
using OpRewritePattern<BinaryOp>::OpRewritePattern;
LogicalResult matchAndRewrite(BinaryOp op,
PatternRewriter &rewriter) const final {
auto loc = op->getLoc();
auto resultTy = dyn_cast<TensorType>(op->getResultTypes().front());
auto lhsOp = GetOperandLHS(op, rewriter, resultTy);
auto rhsOp = GetOperandRHS(op, rewriter, resultTy);
if (!lhsOp || !rhsOp) {
return rewriter.notifyMatchFailure(op, "Failed to get operand.");
}
auto binaryOp = createElemwiseOp<LoweredOp>(loc, rewriter, resultTy, //
{lhsOp, rhsOp});
rewriter.replaceOp(op, binaryOp);
return success();
}
};

//===----------------------------------------------------------------------===//
// Op lowering
//===----------------------------------------------------------------------===//
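// relu(x) = max(x, 0), so ReLU reuses the binary pattern below: the original
// operand paired with a broadcast constant-zero operand for linalg.max.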

using ReLUOpLowering =
BinaryElemwiseLowering<onednn_graph::ReLUOp, linalg::MaxOp, //
OriginalOperand::getIdx<0>,
ConstantOperand::getConst<0>>;

using AddOpLowering =
BinaryElemwiseLowering<onednn_graph::AddOp, linalg::AddOp, //
OriginalOperand::getIdx<0>,
OriginalOperand::getIdx<1>>;

//===----------------------------------------------------------------------===//
// MatMulOp lowering
//===----------------------------------------------------------------------===//

struct MatMulOpLowering : public OpRewritePattern<MatMulOp> {
using OpRewritePattern<MatMulOp>::OpRewritePattern;
LogicalResult matchAndRewrite(MatMulOp op,
PatternRewriter &rewriter) const final {
auto loc = op->getLoc();
auto resultTy = dyn_cast<TensorType>(op->getResultTypes().front());
//
Value zero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(resultTy.getElementType()));
Value newTensor = rewriter.create<tensor::EmptyOp>(
loc, resultTy.getShape(), resultTy.getElementType());
Value outTensor =
rewriter.create<linalg::FillOp>(loc, zero, newTensor).getResult(0);

bool transposeA = op.getTransposeA();
bool transposeB = op.getTransposeB();
Operation *newOp;
if (!transposeA && !transposeB) {
// (A * B)
newOp = rewriter.create<linalg::MatmulOp>(
/*location=*/loc,
/*resultTensorTypes=*/resultTy,
/*inputs=*/ValueRange{op.getInputA(), op.getInputB()},
/*outputs=*/outTensor);
} else if (transposeA && !transposeB) {
// T(A) * B
newOp = rewriter.create<linalg::MatmulTransposeAOp>(
/*location=*/loc,
/*resultTensorTypes=*/resultTy,
/*inputs=*/ValueRange{op.getInputA(), op.getInputB()},
/*outputs=*/outTensor);
} else if (!transposeA && transposeB) {
// A * T(B)
newOp = rewriter.create<linalg::MatmulTransposeBOp>(
/*location=*/loc,
/*resultTensorTypes=*/resultTy,
/*inputs=*/ValueRange{op.getInputA(), op.getInputB()},
/*outputs=*/outTensor);
} else {
// T(A) * T(B) = T(B * A)
int64_t rank = resultTy.getRank();
SmallVector<int64_t> permutation(rank);
std::iota(std::begin(permutation), std::end(permutation), 0);
permutation[rank - 2] = rank - 1;
permutation[rank - 1] = rank - 2;
Contributor: Since this already assumes it's 2-D, can it be static?

Contributor (author): Yeah, for 2-D matmul it should be static.
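// Sketch of the suggestion above (assuming rank == 2 always holds here, as
// agreed in the thread): the permutation becomes a compile-time constant.
//   SmallVector<int64_t> permutation{1, 0};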

auto matmulOp = rewriter.create<linalg::MatmulOp>(
/*location=*/loc,
/*resultTensorTypes=*/resultTy,
/*inputs=*/ValueRange{op.getInputB(), op.getInputA()},
/*outputs=*/outTensor);
newOp = rewriter.create<linalg::TransposeOp>(
/*location=*/loc,
/*inputs=*/matmulOp.getResult(0),
/*outputs=*/outTensor,
/*permutation=*/permutation);
}

if (op.getBias()) {
auto bias = createBroadcastOperand(loc, rewriter, resultTy, op.getBias());
newOp = rewriter.create<linalg::AddOp>(
/*location=*/loc,
/*resultTensorTypes=*/outTensor.getType(),
/*inputs=*/newOp->getResult(0),
/*outputs=*/bias);
}

rewriter.replaceOp(op, newOp);
return success();
}
};

//===----------------------------------------------------------------------===//
// Pass definition
//===----------------------------------------------------------------------===//

struct ConvertOneDNNGraphToLinalg
: public impl::ConvertOneDNNGraphToLinalgBase<ConvertOneDNNGraphToLinalg> {

void runOnOperation() final {
auto *ctx = &getContext();
// add lowering target
ConversionTarget target(getContext());
target.addIllegalDialect<onednn_graph::OneDNNGraphDialect>();
target.addLegalDialect<BuiltinDialect, arith::ArithDialect,
linalg::LinalgDialect, func::FuncDialect,
tensor::TensorDialect>();
// set pattern
RewritePatternSet patterns(ctx);
patterns.add<AddOpLowering>(ctx);
patterns.add<ReLUOpLowering>(ctx);
patterns.add<MatMulOpLowering>(ctx);
// perform conversion
if (failed(
applyFullConversion(getOperation(), target, std::move(patterns)))) {
signalPassFailure();
}
}
};

} // namespace
} // namespace gc
} // namespace mlir
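
For reference, a minimal sketch of driving this pass from C++. The createConvertOneDNNGraphToLinalg() factory name assumes MLIR's default TableGen-generated constructor and does not appear in this diff:

#include "gc/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"

// Sketch only: run the lowering on an already-parsed module. The factory
// name below is an assumption based on MLIR's default pass generation.
static mlir::LogicalResult lowerOneDNNGraph(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::gc::createConvertOneDNNGraphToLinalg());
  return pm.run(module);
}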
38 changes: 38 additions & 0 deletions test/gc/Dialect/OneDNNGraph/onednn-graph-to-linalg.mlir
@@ -0,0 +1,38 @@
// RUN: gc-opt --split-input-file --convert-onednn-graph-to-linalg %s -verify-diagnostics -o - | FileCheck %s

// CHECK-LABEL: @matmul
func.func @matmul(%arg0: tensor<128x512xbf16>, %arg1: tensor<512x256xbf16>) -> tensor<128x256xbf16> {
// CHECK: [[C0:%.+]] = arith.constant 0
// CHECK: [[INIT:%.+]] = tensor.empty()
// CHECK: [[FILLED:%.+]] = linalg.fill ins([[C0]] : bf16) outs([[INIT]] : tensor<128x256xbf16>) -> tensor<128x256xbf16>
// CHECK: linalg.matmul ins(%arg0, %arg1 : tensor<128x512xbf16>, tensor<512x256xbf16>) outs([[FILLED]] : tensor<128x256xbf16>) -> tensor<128x256xbf16>
%0 = onednn_graph.matmul %arg0, %arg1 : (tensor<128x512xbf16>, tensor<512x256xbf16>) -> tensor<128x256xbf16>
return %0 : tensor<128x256xbf16>
}

// CHECK-LABEL: @add
func.func @add(%arg0: tensor<128x256xf32>, %arg1: tensor<128x256xf32>) -> tensor<128x256xf32> {
// CHECK: tensor.empty()
// CHECK: linalg.add
%0 = onednn_graph.add %arg0, %arg1 : (tensor<128x256xf32>, tensor<128x256xf32>) -> tensor<128x256xf32>
return %0 : tensor<128x256xf32>
}

// CHECK-LABEL: @add_bcast
func.func @add_bcast(%arg0: tensor<128x256xf32>, %arg1: tensor<256xf32>) -> tensor<128x256xf32> {
// CHECK: tensor.empty()
// CHECK: linalg.broadcast
// CHECK: tensor.empty()
// CHECK: linalg.add
%0 = onednn_graph.add %arg0, %arg1 : (tensor<128x256xf32>, tensor<256xf32>) -> tensor<128x256xf32>
return %0 : tensor<128x256xf32>
}

// CHECK-LABEL: @relu
func.func @relu(%arg0: tensor<128x256xf32>) -> tensor<128x256xf32> {
// CHECK: arith.constant dense<0.0{{.*}}>
// CHECK: tensor.empty()
// CHECK: linalg.max
%0 = onednn_graph.relu %arg0 : (tensor<128x256xf32>) -> tensor<128x256xf32>
return %0 : tensor<128x256xf32>
}