Commit c06f84d

Author: Longsheng Du
Commit message: rebase
Parent: 6a77167

File tree: 4 files changed, +334 −0 lines

include/gc/Transforms/Passes.td
Lines changed: 14 additions & 0 deletions

@@ -17,4 +17,18 @@ def TileLinalgNamed : Pass<"tile-named-linalg", "func::FuncOp"> {
     ["linalg::LinalgDialect", "scf::SCFDialect", "tensor::TensorDialect"];
 }
 
+def ConvertOneDNNGraphToLinalg : Pass<"convert-onednn-graph-to-linalg"> {
+  let summary = "Lower the operations from the oneDNN Graph dialect into Linalg";
+  let description = [{
+    Lowers the `onednn_graph` ops to `linalg` ops.
+  }];
+  let dependentDialects = [
+    "func::FuncDialect",
+    "math::MathDialect",
+    "arith::ArithDialect",
+    "tensor::TensorDialect",
+    "linalg::LinalgDialect"
+  ];
+}
+
 #endif // GC_DIALECT_GC_PASSES
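
The new pass registers under the flag given in the tablegen entry above; it can be exercised standalone with the project's opt driver, mirroring the RUN line of the test file at the end of this commit (this assumes a built gc-opt binary and an input module containing onednn_graph ops):

  gc-opt --convert-onednn-graph-to-linalg input.mlir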

lib/gc/Transforms/CMakeLists.txt
Lines changed: 2 additions & 0 deletions

@@ -1,4 +1,5 @@
 add_mlir_library(GCPasses
+  OneDNNGraphToLinalg.cpp
   TileNamed.cpp
 
   ADDITIONAL_HEADER_DIRS
@@ -9,6 +10,7 @@ add_mlir_library(GCPasses
 
   LINK_LIBS PUBLIC
     ${mlir_dialect_libs}
+    MLIROneDNNGraph
     MLIRIR
     MLIRSupport
     MLIRBufferizationToMemRef
lib/gc/Transforms/OneDNNGraphToLinalg.cpp (new file)
Lines changed: 280 additions & 0 deletions
@@ -0,0 +1,280 @@
//===- OneDNNGraphToLinalg.cpp - OneDNN Graph To Linalg Lowering -*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <numeric>
#include <vector>

#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h"
#include "gc/Dialect/OneDNNGraph/OneDNNGraphOps.h"
#include "gc/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir::onednn_graph;

namespace mlir {
namespace gc {
#define GEN_PASS_DEF_CONVERTONEDNNGRAPHTOLINALG
#include "gc/Transforms/Passes.h.inc"

namespace {
//===----------------------------------------------------------------------===//
// Util funcs
//===----------------------------------------------------------------------===//

// Broadcast `op` to the shape of `ty` when the shapes differ; otherwise
// return `op` unchanged.
Value createBroadcastOperand(Location loc, PatternRewriter &rewriter,
                             TensorType ty, Value op) {
  auto opTy = dyn_cast<TensorType>(op.getType());
  llvm::ArrayRef<int64_t> bcastShape = ty.getShape();
  llvm::ArrayRef<int64_t> opShape = opTy.getShape();
  int64_t diff = bcastShape.size() - opShape.size();

  if (bcastShape.equals(opShape)) {
    return op;
  } else {
    // get broadcast dimensions
    llvm::SmallVector<int64_t> bcastDims;
    for (int64_t i = 0; i < (int64_t)bcastShape.size(); i++) {
      int64_t idxOp = i - diff;
      if (idxOp < 0) {
        bcastDims.push_back(i);
      } else if (bcastShape[i] != opShape[idxOp]) {
        bcastDims.push_back(i);
      }
    }
    // create a new output tensor
    Value initTensor =
        rewriter.create<tensor::EmptyOp>(loc, bcastShape, ty.getElementType());
    return rewriter
        .create<linalg::BroadcastOp>(
            /*location=*/loc,
            /*inputs=*/op,
            /*inits=*/initTensor,
            /*dimensions=*/bcastDims)
        .getResults()
        .front();
  }
}
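
// Illustrative example (not part of the original commit): broadcasting a
// value of type tensor<256xf32> to tensor<128x256xf32> yields diff = 1 and
// bcastDims = [0], so the helper above emits roughly:
//   %init = tensor.empty() : tensor<128x256xf32>
//   %bc = linalg.broadcast ins(%arg : tensor<256xf32>)
//                          outs(%init : tensor<128x256xf32>) dimensions = [0]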

typedef Value (*GetOperandFn)(Operation *, PatternRewriter &, TensorType);

// Fetches operand `I` of the matched op, broadcast to the result type.
struct OriginalOperand {
  template <unsigned I>
  static Value getIdx(Operation *op, PatternRewriter &b, TensorType ty) {
    if (I >= op->getNumOperands()) {
      op->emitError("Index exceeds operand count.\n");
      return nullptr;
    }
    return createBroadcastOperand(op->getLoc(), b, ty, op->getOperand(I));
  }
};

// Materializes a splat constant `I` with the result type's element type.
struct ConstantOperand {
  template <int64_t I>
  static Value getConst(Operation *op, PatternRewriter &b, TensorType ty) {
    const auto loc = op->getLoc();
    if (llvm::isa<IntegerType>(ty.getElementType())) {
      return b.create<arith::ConstantOp>( //
          loc, DenseElementsAttr::get(ty, int64_t(I)));
    } else if (llvm::isa<FloatType>(ty.getElementType())) {
      return b.create<arith::ConstantOp>( //
          loc, DenseElementsAttr::get(ty, float(I)));
    } else {
      op->emitError("Unsupported element type for constant.\n");
      return nullptr;
    }
  }
};
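
// Illustrative example (not part of the original commit): with
// ty = tensor<128x256xf32>, ConstantOperand::getConst<0> materializes
//   %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32>
// which is exactly what the relu lowering test below checks for.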

//===----------------------------------------------------------------------===//
// Elemwise lowering
//===----------------------------------------------------------------------===//

// Generate elementwise op using linalg named ops
template <typename LoweredOp>
Value createElemwiseOp(Location loc, PatternRewriter &rewriter, TensorType ty,
                       llvm::ArrayRef<Value> inputs) {
  // create a new output tensor
  Value outTensor =
      rewriter.create<tensor::EmptyOp>(loc, ty.getShape(), ty.getElementType());

  auto elemwiseOp = rewriter.create<LoweredOp>(
      /*location=*/loc,
      /*resultTensorTypes=*/outTensor.getType(),
      /*inputs=*/inputs,
      /*outputs=*/outTensor);

  return elemwiseOp.getResult(0);
}

template <typename UnaryOp, typename LoweredOp, GetOperandFn GetOperand>
struct UnaryElemwiseLowering : public OpRewritePattern<UnaryOp> {
  using OpRewritePattern<UnaryOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(UnaryOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op->getLoc();
    auto resultTy = dyn_cast<TensorType>(op->getResultTypes().front());
    auto inOp = GetOperand(op, rewriter, resultTy);
    if (!inOp) {
      return rewriter.notifyMatchFailure(op, "Failed to get operand.");
    }
    auto unaryOp = createElemwiseOp<LoweredOp>(loc, rewriter, resultTy, {inOp});
    rewriter.replaceOp(op, unaryOp);
    return success();
  }
};

template <typename BinaryOp, typename LoweredOp, GetOperandFn GetOperandLHS,
          GetOperandFn GetOperandRHS>
struct BinaryElemwiseLowering : public OpRewritePattern<BinaryOp> {
  using OpRewritePattern<BinaryOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(BinaryOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op->getLoc();
    auto resultTy = dyn_cast<TensorType>(op->getResultTypes().front());
    auto lhsOp = GetOperandLHS(op, rewriter, resultTy);
    auto rhsOp = GetOperandRHS(op, rewriter, resultTy);
    if (!lhsOp || !rhsOp) {
      return rewriter.notifyMatchFailure(op, "Failed to get operand.");
    }
    auto binaryOp = createElemwiseOp<LoweredOp>(loc, rewriter, resultTy, //
                                                {lhsOp, rhsOp});
    rewriter.replaceOp(op, binaryOp);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// Op lowering
//===----------------------------------------------------------------------===//

// ReLU(x) = max(x, 0): lower the unary onednn_graph op through the binary
// pattern, pairing the original operand with a zero constant.
using ReLUOpLowering =
    BinaryElemwiseLowering<onednn_graph::ReLUOp, linalg::MaxOp, //
                           OriginalOperand::getIdx<0>,
                           ConstantOperand::getConst<0>>;

using AddOpLowering =
    BinaryElemwiseLowering<onednn_graph::AddOp, linalg::AddOp, //
                           OriginalOperand::getIdx<0>,
                           OriginalOperand::getIdx<1>>;

//===----------------------------------------------------------------------===//
// MatMulOp lowering
//===----------------------------------------------------------------------===//

struct MatMulOpLowering : public OpRewritePattern<MatMulOp> {
  using OpRewritePattern<MatMulOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(MatMulOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op->getLoc();
    auto resultTy = dyn_cast<TensorType>(op->getResultTypes().front());
    // create a zero-filled init tensor for the matmul accumulator
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(resultTy.getElementType()));
    Value newTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultTy.getElementType());
    Value outTensor =
        rewriter.create<linalg::FillOp>(loc, zero, newTensor).getResult(0);

    bool transposeA = op.getTransposeA();
    bool transposeB = op.getTransposeB();
    Operation *newOp;
    if (!transposeA && !transposeB) {
      // A * B
      newOp = rewriter.create<linalg::MatmulOp>(
          /*location=*/loc,
          /*resultTensorTypes=*/resultTy,
          /*inputs=*/ValueRange{op.getInputA(), op.getInputB()},
          /*outputs=*/outTensor);
    } else if (transposeA && !transposeB) {
      // T(A) * B
      newOp = rewriter.create<linalg::MatmulTransposeAOp>(
          /*location=*/loc,
          /*resultTensorTypes=*/resultTy,
          /*inputs=*/ValueRange{op.getInputA(), op.getInputB()},
          /*outputs=*/outTensor);
    } else if (!transposeA && transposeB) {
      // A * T(B)
      newOp = rewriter.create<linalg::MatmulTransposeBOp>(
          /*location=*/loc,
          /*resultTensorTypes=*/resultTy,
          /*inputs=*/ValueRange{op.getInputA(), op.getInputB()},
          /*outputs=*/outTensor);
    } else {
      // T(A) * T(B) == T(B * A); compute (B * A) into a tensor with the
      // transposed result shape, then transpose into the final shape.
      int64_t rank = resultTy.getRank();
      SmallVector<int64_t> permutation(rank);
      std::iota(std::begin(permutation), std::end(permutation), 0);
      permutation[rank - 2] = rank - 1;
      permutation[rank - 1] = rank - 2;
      SmallVector<int64_t> transShape(resultTy.getShape().begin(),
                                      resultTy.getShape().end());
      std::swap(transShape[rank - 2], transShape[rank - 1]);
      Value transTensor = rewriter.create<tensor::EmptyOp>(
          loc, transShape, resultTy.getElementType());
      Value transFilled =
          rewriter.create<linalg::FillOp>(loc, zero, transTensor).getResult(0);
      auto matmulOp = rewriter.create<linalg::MatmulOp>(
          /*location=*/loc,
          /*resultTensorTypes=*/transFilled.getType(),
          /*inputs=*/ValueRange{op.getInputB(), op.getInputA()},
          /*outputs=*/transFilled);
      newOp = rewriter.create<linalg::TransposeOp>(
          /*location=*/loc,
          /*inputs=*/matmulOp.getResult(0),
          /*outputs=*/outTensor,
          /*permutation=*/permutation);
    }

    if (op.getBias()) {
      // add the broadcast bias; linalg.add takes both addends as inputs
      auto bias = createBroadcastOperand(loc, rewriter, resultTy, op.getBias());
      newOp = rewriter.create<linalg::AddOp>(
          /*location=*/loc,
          /*resultTensorTypes=*/outTensor.getType(),
          /*inputs=*/ValueRange{newOp->getResult(0), bias},
          /*outputs=*/outTensor);
    }

    rewriter.replaceOp(op, newOp);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// Pass definition
//===----------------------------------------------------------------------===//

struct ConvertOneDNNGraphToLinalg
    : public impl::ConvertOneDNNGraphToLinalgBase<ConvertOneDNNGraphToLinalg> {

  void runOnOperation() final {
    auto *ctx = &getContext();
    // add lowering target
    ConversionTarget target(getContext());
    target.addIllegalDialect<onednn_graph::OneDNNGraphDialect>();
    target.addLegalDialect<BuiltinDialect, arith::ArithDialect,
                           linalg::LinalgDialect, func::FuncDialect,
                           tensor::TensorDialect>();
    // set patterns
    RewritePatternSet patterns(ctx);
    patterns.add<AddOpLowering>(ctx);
    patterns.add<ReLUOpLowering>(ctx);
    patterns.add<MatMulOpLowering>(ctx);
    // perform conversion
    if (failed(
            applyFullConversion(getOperation(), target, std::move(patterns)))) {
      signalPassFailure();
    }
  }
};

} // namespace
} // namespace gc
} // namespace mlir
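
For embedding the pass in a larger pipeline, here is a minimal sketch. The factory function name follows MLIR's tablegen conventions and is an assumption, since Passes.h is not part of this diff:

  // Hypothetical pipeline setup; createConvertOneDNNGraphToLinalg() is the
  // conventional generated factory name, not confirmed by this commit.
  #include "mlir/Pass/PassManager.h"

  mlir::PassManager pm(&context);
  pm.addPass(mlir::gc::createConvertOneDNNGraphToLinalg());
  if (mlir::failed(pm.run(module))) {
    // handle conversion failure
  }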
Lines changed: 38 additions & 0 deletions

@@ -0,0 +1,38 @@
// RUN: gc-opt --split-input-file --convert-onednn-graph-to-linalg %s -verify-diagnostics -o - | FileCheck %s

// CHECK-LABEL: @matmul
func.func @matmul(%arg0: tensor<128x512xbf16>, %arg1: tensor<512x256xbf16>) -> tensor<128x256xbf16> {
  // CHECK: [[C0:%.+]] = arith.constant 0
  // CHECK: [[INIT:%.+]] = tensor.empty()
  // CHECK: [[FILLED:%.+]] = linalg.fill ins([[C0]] : bf16) outs([[INIT]] : tensor<128x256xbf16>) -> tensor<128x256xbf16>
  // CHECK: linalg.matmul ins(%arg0, %arg1 : tensor<128x512xbf16>, tensor<512x256xbf16>) outs([[FILLED]] : tensor<128x256xbf16>) -> tensor<128x256xbf16>
  %0 = onednn_graph.matmul %arg0, %arg1 : (tensor<128x512xbf16>, tensor<512x256xbf16>) -> tensor<128x256xbf16>
  return %0 : tensor<128x256xbf16>
}

// CHECK-LABEL: @add
func.func @add(%arg0: tensor<128x256xf32>, %arg1: tensor<128x256xf32>) -> tensor<128x256xf32> {
  // CHECK: tensor.empty()
  // CHECK: linalg.add
  %0 = onednn_graph.add %arg0, %arg1 : (tensor<128x256xf32>, tensor<128x256xf32>) -> tensor<128x256xf32>
  return %0 : tensor<128x256xf32>
}

// CHECK-LABEL: @add_bcast
func.func @add_bcast(%arg0: tensor<128x256xf32>, %arg1: tensor<256xf32>) -> tensor<128x256xf32> {
  // CHECK: tensor.empty()
  // CHECK: linalg.broadcast
  // CHECK: tensor.empty()
  // CHECK: linalg.add
  %0 = onednn_graph.add %arg0, %arg1 : (tensor<128x256xf32>, tensor<256xf32>) -> tensor<128x256xf32>
  return %0 : tensor<128x256xf32>
}

// CHECK-LABEL: @relu
func.func @relu(%arg0: tensor<128x256xf32>) -> tensor<128x256xf32> {
  // CHECK: arith.constant dense<0.0{{.*}}>
  // CHECK: tensor.empty()
  // CHECK: linalg.max
  %0 = onednn_graph.relu %arg0 : (tensor<128x256xf32>) -> tensor<128x256xf32>
  return %0 : tensor<128x256xf32>
}
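
For reference, a sketch of the IR this pass should produce for the @add_bcast case above (SSA names are illustrative, not the pass's verbatim output):

  func.func @add_bcast(%arg0: tensor<128x256xf32>, %arg1: tensor<256xf32>) -> tensor<128x256xf32> {
    %0 = tensor.empty() : tensor<128x256xf32>
    %broadcasted = linalg.broadcast ins(%arg1 : tensor<256xf32>) outs(%0 : tensor<128x256xf32>) dimensions = [0]
    %1 = tensor.empty() : tensor<128x256xf32>
    %2 = linalg.add ins(%arg0, %broadcasted : tensor<128x256xf32>, tensor<128x256xf32>) outs(%1 : tensor<128x256xf32>) -> tensor<128x256xf32>
    return %2 : tensor<128x256xf32>
  }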
