Skip to content

Commit d6a2e7a

Browse files
author
Longsheng Du
committed
add
1 parent 0b57b20 commit d6a2e7a

File tree

1 file changed

+293
-0
lines changed

1 file changed

+293
-0
lines changed
Lines changed: 293 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,293 @@
1+
//===-- OneDNNGraphToLinalg.cpp - OneDNN Graph To Linalg Lowering -*- C++ -*-===//
3+
//
4+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#include <numeric>
11+
#include <vector>
12+
13+
#include "gc-dialects/OneDNNGraph/OneDNNGraphDialect.h"
14+
#include "gc-dialects/OneDNNGraph/OneDNNGraphOps.h"
15+
#include "gc-dialects/Passes.h"
16+
#include "mlir/Dialect/Func/IR/FuncOps.h"
17+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
19+
#include "mlir/Dialect/Math/IR/Math.h"
20+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
21+
#include "mlir/IR/PatternMatch.h"
22+
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
23+
#include "mlir/Support/LogicalResult.h"
24+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25+
26+
using namespace mlir::onednn_graph;
27+
28+
namespace mlir {
29+
namespace gc {
30+
#define GEN_PASS_DEF_CONVERTONEDNNGRAPHTOLINALG
31+
#include "gc-dialects/Passes.h.inc"
32+
33+
namespace {
34+
//===----------------------------------------------------------------------===//
35+
// Util funcs
36+
//===----------------------------------------------------------------------===//
37+
38+
// Create lowered elementwise Op
39+
// Abstract factory for the scalar body of a lowered elementwise op.
// Implementations emit the per-element computation inside the region of the
// linalg::GenericOp built by createElemwiseOp.
struct CreateElementwiseOp {
  virtual ~CreateElementwiseOp() = default;
  // Build the scalar op for one element; `args` are the region block
  // arguments (the input scalars followed by the output operand).
  virtual Value create(OpBuilder &b, Location loc, ValueRange args) const = 0;
};
43+
44+
// Generate elementwise op using linalg::GenericOp
45+
// Lower an elementwise op to a linalg::GenericOp producing a tensor of type
// `ty`. All indexing maps are identities and every loop dimension is
// parallel; the scalar computation is supplied by `createOp`.
Value createElemwiseOp(Location loc, PatternRewriter &rewriter, TensorType ty,
                       llvm::ArrayRef<Value> inputs,
                       const CreateElementwiseOp &createOp) {
  // One identity map per input, plus one for the output.
  llvm::SmallVector<AffineMap> maps(
      inputs.size() + 1, rewriter.getMultiDimIdentityMap(ty.getRank()));
  // Purely elementwise: no reduction axes, everything parallel.
  llvm::SmallVector<utils::IteratorType> iterTypes(
      ty.getRank(), utils::IteratorType::parallel);

  // Destination tensor that receives the result.
  Value emptyOut =
      rewriter.create<tensor::EmptyOp>(loc, ty.getShape(), ty.getElementType());

  auto genericOp = rewriter.create<linalg::GenericOp>(
      /*location=*/loc,
      /*resultTensorTypes=*/emptyOut.getType(),
      /*inputs=*/inputs,
      /*outputs=*/emptyOut,
      /*indexingMaps=*/maps,
      /*iteratorTypes=*/iterTypes,
      [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
        // Delegate the scalar body to the supplied factory and yield it.
        Value scalar = createOp.create(b, nestedLoc, args);
        b.create<linalg::YieldOp>(nestedLoc, scalar);
      });

  return genericOp.getResult(0);
}
75+
76+
// Broadcast `op` to the target shape of `ty` via linalg::BroadcastOp, or
// return it unchanged when the shapes already match exactly.
Value createBroadcastOperand(Location loc, PatternRewriter &rewriter,
                             TensorType ty, Value op) {
  // NOTE(review): result of dyn_cast is used unchecked below — assumes every
  // operand reaching here is a tensor; confirm against the callers.
  auto opTy = dyn_cast<TensorType>(op.getType());
  llvm::ArrayRef<int64_t> bcastShape = ty.getShape();
  llvm::ArrayRef<int64_t> opShape = opTy.getShape();
  // Rank difference: number of leading result dims absent from the operand.
  int64_t diff = bcastShape.size() - opShape.size();
  //
  if (bcastShape.equals(opShape)) {
    // Shapes already agree; no broadcast needed.
    return op;
  } else {
    // Collect result dimensions the operand does not provide: every leading
    // dim covered by the rank difference, plus any dim whose extent differs.
    // NOTE(review): linalg.broadcast only *adds* dimensions (input rank +
    // dimensions.size() must equal the init rank); including a
    // differing-extent dim (e.g. 1 -> N) here looks like it would fail
    // verification — confirm intended semantics.
    llvm::SmallVector<int64_t> bcastDims;
    for (int64_t i = 0; i < (int64_t)bcastShape.size(); i++) {
      int64_t idxOp = i - diff;
      if (idxOp < 0) {
        bcastDims.push_back(i);
      } else if (bcastShape[i] != opShape[idxOp]) {
        bcastDims.push_back(i);
      }
    }
    // Destination tensor with the broadcast target shape.
    Value initTensor =
        rewriter.create<tensor::EmptyOp>(loc, bcastShape, ty.getElementType());
    return rewriter
        .create<linalg::BroadcastOp>(
            /*location=*/loc,
            /*inputs=*/op,
            /*inits=*/initTensor,
            /*dimensions=*/bcastDims)
        .getResults()
        .front();
  }
}
109+
110+
//===----------------------------------------------------------------------===//
111+
// UnaryOp lowering
112+
//===----------------------------------------------------------------------===//
113+
114+
template <typename UnaryOp, typename CreateLoweredOp>
115+
struct UnaryElemwiseLowering : public OpRewritePattern<UnaryOp> {
116+
using OpRewritePattern<UnaryOp>::OpRewritePattern;
117+
LogicalResult matchAndRewrite(UnaryOp op,
118+
PatternRewriter &rewriter) const final {
119+
auto loc = op->getLoc();
120+
auto resultTy = dyn_cast<TensorType>(op->getResultTypes().front());
121+
//
122+
auto unaryOp = createElemwiseOp(loc, rewriter, resultTy,
123+
{op->getOperand(0)}, CreateLoweredOp());
124+
rewriter.replaceOp(op, unaryOp);
125+
return success();
126+
}
127+
};
128+
129+
// Generic scalar-body factory for a unary op: emits `LoweredOp(args[0])`.
template <typename LoweredOp>
struct CreateLoweredUnaryOp : public CreateElementwiseOp {
  Value create(OpBuilder &b, Location loc, ValueRange args) const final {
    return b.create<LoweredOp>(loc, args[0]);
  }
};
135+
136+
// Scalar-body factory for ReLU: max(x, 0).
struct CreateLoweredReLUOp : public CreateElementwiseOp {
  Value create(OpBuilder &b, Location loc, ValueRange args) const final {
    Value input = args[0];
    // Zero constant of the input's type — assumes a float element type
    // (FloatAttr / MaximumFOp); integer ReLU would need a separate path.
    Value zeros =
        b.create<arith::ConstantOp>(loc, FloatAttr::get(input.getType(), 0.f));
    return b.create<arith::MaximumFOp>(loc, input, zeros);
  }
};
144+
145+
//===----------------------------------------------------------------------===//
146+
// BinaryOp lowering
147+
//===----------------------------------------------------------------------===//
148+
149+
template <typename BinaryOp, typename CreateLoweredOp>
150+
struct BinaryElemwiseLowering : public OpRewritePattern<BinaryOp> {
151+
using OpRewritePattern<BinaryOp>::OpRewritePattern;
152+
LogicalResult matchAndRewrite(BinaryOp op,
153+
PatternRewriter &rewriter) const final {
154+
auto loc = op->getLoc();
155+
auto resultTy = dyn_cast<TensorType>(op->getResultTypes().front());
156+
//
157+
auto lhsOp =
158+
createBroadcastOperand(loc, rewriter, resultTy, op->getOperand(0));
159+
auto rhsOp =
160+
createBroadcastOperand(loc, rewriter, resultTy, op->getOperand(1));
161+
//
162+
auto binaryOp = createElemwiseOp(loc, rewriter, resultTy, {lhsOp, rhsOp},
163+
CreateLoweredOp());
164+
rewriter.replaceOp(op, binaryOp);
165+
return success();
166+
}
167+
};
168+
169+
// Generic scalar-body factory for a binary op: emits
// `LoweredOp(args[0], args[1])`.
template <typename LoweredOp>
struct CreateLoweredBinaryOp : public CreateElementwiseOp {
  Value create(OpBuilder &b, Location loc, ValueRange args) const final {
    return b.create<LoweredOp>(loc, args[0], args[1]);
  }
};
175+
176+
//===----------------------------------------------------------------------===//
177+
// Op lowering
178+
//===----------------------------------------------------------------------===//
179+
180+
// Unary elementwise lowerings (disabled ones kept for future enablement).
using ReLUOpLowering = UnaryElemwiseLowering< //
    onednn_graph::ReLUOp, CreateLoweredReLUOp>;
// using ExpOpLowering = UnaryElemwiseLowering< //
//     onednn_graph::ExpOp, CreateLoweredUnaryOp<math::ExpOp>>;

// Binary elementwise lowerings (disabled ones kept for future enablement).
using AddOpLowering = BinaryElemwiseLowering< //
    onednn_graph::AddOp, CreateLoweredBinaryOp<arith::AddFOp>>;
// using SubOpLowering = BinaryElemwiseLowering< //
//     onednn_graph::SubOp, CreateLoweredBinaryOp<arith::SubFOp>>;
// using MulOpLowering = BinaryElemwiseLowering< //
//     onednn_graph::MulOp, CreateLoweredBinaryOp<arith::MulFOp>>;
// using DivOpLowering = BinaryElemwiseLowering< //
//     onednn_graph::DivOp, CreateLoweredBinaryOp<arith::DivFOp>>;
193+
194+
//===----------------------------------------------------------------------===//
195+
// MatMulOp lowering
196+
//===----------------------------------------------------------------------===//
197+
198+
// Lowers onednn_graph::MatMulOp to the linalg matmul variants, covering all
// four transpose_a/transpose_b combinations plus an optional bias addend.
struct MatMulOpLowering : public OpRewritePattern<MatMulOp> {
  using OpRewritePattern<MatMulOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(MatMulOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op->getLoc();
    auto resultTy = dyn_cast<TensorType>(op->getResultTypes().front());
    // BUGFIX: dyn_cast can return null (non-tensor result); fail the match
    // instead of dereferencing a null type below.
    if (!resultTy) {
      return failure();
    }
    // Zero constant used to initialize the matmul accumulators.
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(resultTy.getElementType()));
    // Creates a zero-filled destination tensor of the given shape.
    auto createZeroInit = [&](llvm::ArrayRef<int64_t> shape) -> Value {
      Value empty = rewriter.create<tensor::EmptyOp>(
          loc, shape, resultTy.getElementType());
      return rewriter.create<linalg::FillOp>(loc, zero, empty).getResult(0);
    };
    Value outTensor = createZeroInit(resultTy.getShape());

    bool transposeA = op.getTransposeA();
    bool transposeB = op.getTransposeB();
    Operation *newOp;
    if (!transposeA && !transposeB) {
      // A * B
      newOp = rewriter.create<linalg::MatmulOp>(
          /*location=*/loc,
          /*resultTensorTypes=*/resultTy,
          /*inputs=*/ValueRange{op.getInputA(), op.getInputB()},
          /*outputs=*/outTensor);
    } else if (transposeA && !transposeB) {
      // T(A) * B
      newOp = rewriter.create<linalg::MatmulTransposeAOp>(
          /*location=*/loc,
          /*resultTensorTypes=*/resultTy,
          /*inputs=*/ValueRange{op.getInputA(), op.getInputB()},
          /*outputs=*/outTensor);
    } else if (!transposeA && transposeB) {
      // A * T(B)
      newOp = rewriter.create<linalg::MatmulTransposeBOp>(
          /*location=*/loc,
          /*resultTensorTypes=*/resultTy,
          /*inputs=*/ValueRange{op.getInputA(), op.getInputB()},
          /*outputs=*/outTensor);
    } else {
      // T(A) * T(B) == T(B * A): compute B * A into a tensor of the
      // *transposed* result shape, then transpose into the final result.
      int64_t rank = resultTy.getRank();
      // Permutation swapping the two innermost dimensions.
      SmallVector<int64_t> permutation(rank);
      std::iota(std::begin(permutation), std::end(permutation), 0);
      permutation[rank - 2] = rank - 1;
      permutation[rank - 1] = rank - 2;
      // BUGFIX: the intermediate B * A product has the transposed shape;
      // the previous code used resultTy/outTensor for it, which only
      // type-checks for square matrices.
      SmallVector<int64_t> transShape{resultTy.getShape()};
      int64_t tmp = transShape[rank - 2];
      transShape[rank - 2] = transShape[rank - 1];
      transShape[rank - 1] = tmp;
      auto transTy =
          RankedTensorType::get(transShape, resultTy.getElementType());
      auto matmulOp = rewriter.create<linalg::MatmulOp>(
          /*location=*/loc,
          /*resultTensorTypes=*/transTy,
          /*inputs=*/ValueRange{op.getInputB(), op.getInputA()},
          /*outputs=*/createZeroInit(transShape));
      newOp = rewriter.create<linalg::TransposeOp>(
          /*location=*/loc,
          /*inputs=*/matmulOp.getResult(0),
          /*outputs=*/outTensor,
          /*permutation=*/permutation);
    }

    if (op.getBias()) {
      // BUGFIX: linalg.add takes TWO inputs plus a separate init tensor;
      // the bias was previously passed as the init with only one input.
      auto bias = createBroadcastOperand(loc, rewriter, resultTy, op.getBias());
      Value addInit = rewriter.create<tensor::EmptyOp>(
          loc, resultTy.getShape(), resultTy.getElementType());
      newOp = rewriter.create<linalg::AddOp>(
          /*location=*/loc,
          /*resultTensorTypes=*/outTensor.getType(),
          /*inputs=*/ValueRange{newOp->getResult(0), bias},
          /*outputs=*/addInit);
    }

    rewriter.replaceOp(op, newOp);
    return success();
  }
};
268+
269+
//===----------------------------------------------------------------------===//
270+
// Pass define
271+
//===----------------------------------------------------------------------===//
272+
273+
struct ConvertOneDNNGraphToLinalg
274+
: public impl::ConvertOneDNNGraphToLinalgBase<ConvertOneDNNGraphToLinalg> {
275+
276+
void runOnOperation() final {
277+
//
278+
auto *ctx = &getContext();
279+
RewritePatternSet patterns(ctx);
280+
patterns.add<AddOpLowering>(ctx);
281+
patterns.add<ReLUOpLowering>(ctx);
282+
patterns.add<MatMulOpLowering>(ctx);
283+
//
284+
if (failed(applyPatternsAndFoldGreedily(getOperation(),
285+
std::move(patterns)))) {
286+
signalPassFailure();
287+
}
288+
}
289+
};
290+
291+
} // namespace
292+
} // namespace gc
293+
} // namespace mlir

0 commit comments

Comments
 (0)