
Commit 5ee1278

Longsheng Du committed
add lowering

1 parent 4e629e5

File tree

4 files changed: +306 −0 lines changed

include/gc-dialects/Passes.td

Lines changed: 14 additions & 0 deletions

@@ -17,4 +17,18 @@ def TileLinalgNamed : Pass<"tile-named-linalg", "func::FuncOp"> {
     ["linalg::LinalgDialect", "scf::SCFDialect", "tensor::TensorDialect"];
 }
 
+def ConvertOneDNNGraphToLinalg : Pass<"convert-onednn-graph-to-linalg"> {
+  let summary = "Lower the operations from the oneDNN Graph dialect into Linalg";
+  let description = [{
+    Lowers the `onednn_graph` ops to `linalg` ops.
+  }];
+  let dependentDialects = [
+    "func::FuncDialect",
+    "math::MathDialect",
+    "arith::ArithDialect",
+    "tensor::TensorDialect",
+    "linalg::LinalgDialect"
+  ];
+}
+
 #endif // GC_DIALECT_GC_PASSES
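A pass declared this way can be driven from the gc-opt command line (as the lit test at the bottom of this commit does) or added to a pipeline programmatically. A minimal sketch of the latter, assuming the conventional tablegen-generated factory name mlir::gc::createConvertOneDNNGraphToLinalg(), which this diff does not show:

#include "gc-dialects/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"

// Run the lowering on every function in the module. The factory name is the
// usual tablegen convention and is an assumption, not part of this diff.
void buildLoweringPipeline(mlir::PassManager &pm) {
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::gc::createConvertOneDNNGraphToLinalg());
}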

lib/gc-dialects/Transforms/CMakeLists.txt

Lines changed: 2 additions & 0 deletions

@@ -1,4 +1,5 @@
 add_mlir_library(GCPasses
+  OneDNNGraphToLinalg.cpp
   TileNamed.cpp
 
   ADDITIONAL_HEADER_DIRS
@@ -9,6 +10,7 @@ add_mlir_library(GCPasses
 
   LINK_LIBS PUBLIC
     ${mlir_dialect_libs}
+    MLIROneDNNGraph
     MLIRIR
     MLIRSupport
     MLIRBufferizationToMemRef
lib/gc-dialects/Transforms/OneDNNGraphToLinalg.cpp (new file)

Lines changed: 250 additions & 0 deletions

@@ -0,0 +1,250 @@
//===- OneDNNGraphToLinalg.cpp - OneDNN Graph To Linalg Lowering -*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <numeric>
#include <vector>

#include "gc-dialects/OneDNNGraph/OneDNNGraphDialect.h"
#include "gc-dialects/OneDNNGraph/OneDNNGraphOps.h"
#include "gc-dialects/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir::onednn_graph;

namespace mlir {
namespace gc {
#define GEN_PASS_DEF_CONVERTONEDNNGRAPHTOLINALG
#include "gc-dialects/Passes.h.inc"

namespace {
//===----------------------------------------------------------------------===//
// Util funcs
//===----------------------------------------------------------------------===//

// Broadcast `op` up to the target shape `ty` with a linalg.broadcast;
// returns `op` unchanged when the shapes already match.
Value createBroadcastOperand(Location loc, PatternRewriter &rewriter,
                             TensorType ty, Value op) {
  auto opTy = dyn_cast<TensorType>(op.getType());
  llvm::ArrayRef<int64_t> bcastShape = ty.getShape();
  llvm::ArrayRef<int64_t> opShape = opTy.getShape();
  int64_t diff = bcastShape.size() - opShape.size();

  if (bcastShape.equals(opShape)) {
    return op;
  } else {
    // get broadcast dimensions
    llvm::SmallVector<int64_t> bcastDims;
    for (int64_t i = 0; i < (int64_t)bcastShape.size(); i++) {
      int64_t idxOp = i - diff;
      if (idxOp < 0) {
        bcastDims.push_back(i);
      } else if (bcastShape[i] != opShape[idxOp]) {
        bcastDims.push_back(i);
      }
    }
    // create a new output tensor
    Value initTensor =
        rewriter.create<tensor::EmptyOp>(loc, bcastShape, ty.getElementType());
    return rewriter
        .create<linalg::BroadcastOp>(
            /*location=*/loc,
            /*inputs=*/op,
            /*inits=*/initTensor,
            /*dimensions=*/bcastDims)
        .getResults()
        .front();
  }
}
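To see what bcastDims ends up holding, here is the same index arithmetic replayed on plain vectors as a standalone sketch, with the shapes taken from the @add_bcast test below (128x256 result, 256 operand):

#include <cstdint>
#include <iostream>
#include <vector>

// Standalone replica of the bcastDims loop above: a dimension is broadcast
// if it has no counterpart in the operand (idxOp < 0) or the sizes differ.
int main() {
  std::vector<int64_t> bcastShape = {128, 256}; // target shape
  std::vector<int64_t> opShape = {256};         // operand shape
  int64_t diff = (int64_t)bcastShape.size() - (int64_t)opShape.size();

  std::vector<int64_t> bcastDims;
  for (int64_t i = 0; i < (int64_t)bcastShape.size(); i++) {
    int64_t idxOp = i - diff;
    if (idxOp < 0 || bcastShape[i] != opShape[idxOp])
      bcastDims.push_back(i);
  }
  // prints: 0 -- linalg.broadcast duplicates the 256-vector along dim 0
  for (int64_t d : bcastDims)
    std::cout << d << "\n";
}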
// Hooks for preparing a pattern's operands: either forward (and broadcast)
// an original operand, or materialize a constant tensor.
typedef Value (*OperandGet)(Operation *, PatternRewriter &, TensorType);

// Forward operand I of the matched op, broadcast to the result shape.
template <unsigned I>
Value OriginalOperand(Operation *op, PatternRewriter &b, TensorType ty) {
  return createBroadcastOperand(op->getLoc(), b, ty, op->getOperand(I));
}

// Materialize a zero-filled tensor of the result shape.
static Value ConstZeroOperand(Operation *op, PatternRewriter &b,
                              TensorType ty) {
  auto loc = op->getLoc();
  Value zero =
      b.create<arith::ConstantOp>(loc, b.getZeroAttr(ty.getElementType()));
  Value newTensor =
      b.create<tensor::EmptyOp>(loc, ty.getShape(), ty.getElementType());
  return b.create<linalg::FillOp>(loc, zero, newTensor).getResult(0);
}
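These two helpers are the only OperandGet implementations the commit needs, but the hook makes other prepared operands cheap to add. A hypothetical variant that materializes a splat of one, sketched here only for illustration (it is not part of this diff, and it assumes Builder::getOneAttr, which recent MLIR provides):

// Hypothetical OperandGet hook: a tensor filled with ones. Mirrors
// ConstZeroOperand above; not part of this commit.
static Value ConstOneOperand(Operation *op, PatternRewriter &b,
                             TensorType ty) {
  auto loc = op->getLoc();
  Value one =
      b.create<arith::ConstantOp>(loc, b.getOneAttr(ty.getElementType()));
  Value newTensor =
      b.create<tensor::EmptyOp>(loc, ty.getShape(), ty.getElementType());
  return b.create<linalg::FillOp>(loc, one, newTensor).getResult(0);
}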
//===----------------------------------------------------------------------===//
// Elemwise lowering
//===----------------------------------------------------------------------===//

// Generate elementwise op using linalg named ops
template <typename LoweredOp>
Value createElemwiseOp(Location loc, PatternRewriter &rewriter, TensorType ty,
                       llvm::ArrayRef<Value> inputs) {
  // create a new output tensor
  Value outTensor =
      rewriter.create<tensor::EmptyOp>(loc, ty.getShape(), ty.getElementType());

  auto elemwiseOp = rewriter.create<LoweredOp>(
      /*location=*/loc,
      /*resultTensorTypes=*/outTensor.getType(),
      /*inputs=*/inputs,
      /*outputs=*/outTensor);

  return elemwiseOp.getResult(0);
}

template <typename UnaryOp, typename LoweredOp>
struct UnaryElemwiseLowering : public OpRewritePattern<UnaryOp> {
  using OpRewritePattern<UnaryOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(UnaryOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op->getLoc();
    auto resultTy = dyn_cast<TensorType>(op->getResultTypes().front());
    auto unaryOp = createElemwiseOp<LoweredOp>(loc, rewriter, resultTy, //
                                               {op->getOperand(0)});
    rewriter.replaceOp(op, unaryOp);
    return success();
  }
};

template <typename BinaryOp, typename LoweredOp, OperandGet GetOperandLHS,
          OperandGet GetOperandRHS>
struct BinaryElemwiseLowering : public OpRewritePattern<BinaryOp> {
  using OpRewritePattern<BinaryOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(BinaryOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op->getLoc();
    auto resultTy = dyn_cast<TensorType>(op->getResultTypes().front());
    auto lhsOp = GetOperandLHS(op, rewriter, resultTy);
    auto rhsOp = GetOperandRHS(op, rewriter, resultTy);
    auto binaryOp = createElemwiseOp<LoweredOp>(loc, rewriter, resultTy, //
                                                {lhsOp, rhsOp});
    rewriter.replaceOp(op, binaryOp);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// Op lowering
//===----------------------------------------------------------------------===//

// ReLU(x) = max(x, 0): reuse the binary template with a constant-zero RHS.
using ReLUOpLowering = BinaryElemwiseLowering< //
    onednn_graph::ReLUOp, linalg::MaxOp, OriginalOperand<0>, ConstZeroOperand>;

using AddOpLowering = BinaryElemwiseLowering< //
    onednn_graph::AddOp, linalg::AddOp, OriginalOperand<0>, OriginalOperand<1>>;
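The two aliases above are the whole extension surface: a new elementwise lowering is one using declaration plus a patterns.add call in the pass below. For example, a multiply lowering could look like the following sketch, assuming an onednn_graph::MultiplyOp exists and lowers to linalg.mul; neither is added by this diff:

// Hypothetical: not part of this commit.
using MultiplyOpLowering = BinaryElemwiseLowering< //
    onednn_graph::MultiplyOp, linalg::MulOp, OriginalOperand<0>,
    OriginalOperand<1>>;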
//===----------------------------------------------------------------------===//
// MatMulOp lowering
//===----------------------------------------------------------------------===//

struct MatMulOpLowering : public OpRewritePattern<MatMulOp> {
  using OpRewritePattern<MatMulOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(MatMulOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op->getLoc();
    auto resultTy = dyn_cast<TensorType>(op->getResultTypes().front());
    // zero-filled init tensor for the matmul accumulator
    Value newTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultTy.getElementType());
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(resultTy.getElementType()));
    Value outTensor =
        rewriter.create<linalg::FillOp>(loc, zero, newTensor).getResult(0);

    bool transposeA = op.getTransposeA();
    bool transposeB = op.getTransposeB();
    Operation *newOp;
    if (!transposeA && !transposeB) {
      // A * B
      newOp = rewriter.create<linalg::MatmulOp>(
          /*location=*/loc,
          /*resultTensorTypes=*/resultTy,
          /*inputs=*/ValueRange{op.getInputA(), op.getInputB()},
          /*outputs=*/outTensor);
    } else if (transposeA && !transposeB) {
      // T(A) * B
      newOp = rewriter.create<linalg::MatmulTransposeAOp>(
          /*location=*/loc,
          /*resultTensorTypes=*/resultTy,
          /*inputs=*/ValueRange{op.getInputA(), op.getInputB()},
          /*outputs=*/outTensor);
    } else if (!transposeA && transposeB) {
      // A * T(B)
      newOp = rewriter.create<linalg::MatmulTransposeBOp>(
          /*location=*/loc,
          /*resultTensorTypes=*/resultTy,
          /*inputs=*/ValueRange{op.getInputA(), op.getInputB()},
          /*outputs=*/outTensor);
    } else {
      // T(A) * T(B) = T(B * A); the intermediate B * A has the result shape
      // with its last two dims swapped, so it needs its own init tensor.
      int64_t rank = resultTy.getRank();
      SmallVector<int64_t> permutation(rank);
      std::iota(std::begin(permutation), std::end(permutation), 0);
      std::swap(permutation[rank - 2], permutation[rank - 1]);
      SmallVector<int64_t> transShape = llvm::to_vector(resultTy.getShape());
      std::swap(transShape[rank - 2], transShape[rank - 1]);
      Value transTensor = rewriter.create<tensor::EmptyOp>(
          loc, transShape, resultTy.getElementType());
      Value transOut =
          rewriter.create<linalg::FillOp>(loc, zero, transTensor).getResult(0);
      auto matmulOp = rewriter.create<linalg::MatmulOp>(
          /*location=*/loc,
          /*resultTensorTypes=*/transOut.getType(),
          /*inputs=*/ValueRange{op.getInputB(), op.getInputA()},
          /*outputs=*/transOut);
      newOp = rewriter.create<linalg::TransposeOp>(
          /*location=*/loc,
          /*inputs=*/matmulOp.getResult(0),
          /*outputs=*/outTensor,
          /*permutation=*/permutation);
    }

    if (op.getBias()) {
      // bias add: out = matmul_result + broadcast(bias)
      auto bias = createBroadcastOperand(loc, rewriter, resultTy, op.getBias());
      newOp = rewriter.create<linalg::AddOp>(
          /*location=*/loc,
          /*resultTensorTypes=*/outTensor.getType(),
          /*inputs=*/ValueRange{newOp->getResult(0), bias},
          /*outputs=*/outTensor);
    }

    rewriter.replaceOp(op, newOp);
    return success();
  }
};
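The last branch leans on the identity T(A) * T(B) = T(B * A), which saves one of the two transposes an naive lowering would emit. A quick numeric sanity check of that identity on 2x2 matrices, as a standalone sketch:

#include <array>
#include <iostream>

// Numeric check of the identity used above: T(A) * T(B) == T(B * A).
using Mat2 = std::array<std::array<int, 2>, 2>;

Mat2 mul(const Mat2 &x, const Mat2 &y) {
  Mat2 r{};
  for (int i = 0; i < 2; i++)
    for (int j = 0; j < 2; j++)
      for (int k = 0; k < 2; k++)
        r[i][j] += x[i][k] * y[k][j];
  return r;
}

Mat2 transpose(const Mat2 &x) {
  Mat2 r{};
  for (int i = 0; i < 2; i++)
    for (int j = 0; j < 2; j++)
      r[i][j] = x[j][i];
  return r;
}

int main() {
  Mat2 a = {{{1, 2}, {3, 4}}};
  Mat2 b = {{{5, 6}, {7, 8}}};
  Mat2 lhs = mul(transpose(a), transpose(b));
  Mat2 rhs = transpose(mul(b, a));
  std::cout << (lhs == rhs ? "identity holds" : "mismatch") << "\n";
}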
//===----------------------------------------------------------------------===//
// Pass define
//===----------------------------------------------------------------------===//

struct ConvertOneDNNGraphToLinalg
    : public impl::ConvertOneDNNGraphToLinalgBase<ConvertOneDNNGraphToLinalg> {

  void runOnOperation() final {
    auto *ctx = &getContext();
    RewritePatternSet patterns(ctx);
    patterns.add<AddOpLowering>(ctx);
    patterns.add<ReLUOpLowering>(ctx);
    patterns.add<MatMulOpLowering>(ctx);

    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns)))) {
      signalPassFailure();
    }
  }
};

} // namespace
} // namespace gc
} // namespace mlir
Lines changed: 40 additions & 0 deletions

@@ -0,0 +1,40 @@
// RUN: gc-opt --split-input-file -pass-pipeline="builtin.module(func.func(convert-onednn-graph-to-linalg))" %s -verify-diagnostics -o - | FileCheck %s

// CHECK-LABEL: @matmul
func.func @matmul(%arg0: tensor<128x512xbf16>, %arg1: tensor<512x256xbf16>) -> tensor<128x256xbf16> {
  // CHECK: [[C0:%.+]] = arith.constant 0
  // CHECK: [[INIT:%.+]] = tensor.empty()
  // CHECK: [[FILLED:%.+]] = linalg.fill ins([[C0]] : bf16) outs([[INIT]] : tensor<128x256xbf16>) -> tensor<128x256xbf16>
  // CHECK: linalg.matmul ins(%arg0, %arg1 : tensor<128x512xbf16>, tensor<512x256xbf16>) outs([[FILLED]] : tensor<128x256xbf16>) -> tensor<128x256xbf16>
  %0 = onednn_graph.matmul %arg0, %arg1 : (tensor<128x512xbf16>, tensor<512x256xbf16>) -> tensor<128x256xbf16>
  return %0 : tensor<128x256xbf16>
}

// CHECK-LABEL: @add
func.func @add(%arg0: tensor<128x256xf32>, %arg1: tensor<128x256xf32>) -> tensor<128x256xf32> {
  // CHECK: tensor.empty()
  // CHECK: linalg.add
  %0 = onednn_graph.add %arg0, %arg1 : (tensor<128x256xf32>, tensor<128x256xf32>) -> tensor<128x256xf32>
  return %0 : tensor<128x256xf32>
}

// CHECK-LABEL: @add_bcast
func.func @add_bcast(%arg0: tensor<128x256xf32>, %arg1: tensor<256xf32>) -> tensor<128x256xf32> {
  // CHECK: tensor.empty()
  // CHECK: linalg.broadcast
  // CHECK: tensor.empty()
  // CHECK: linalg.add
  %0 = onednn_graph.add %arg0, %arg1 : (tensor<128x256xf32>, tensor<256xf32>) -> tensor<128x256xf32>
  return %0 : tensor<128x256xf32>
}

// CHECK-LABEL: @relu
func.func @relu(%arg0: tensor<128x256xf32>) -> tensor<128x256xf32> {
  // CHECK: arith.constant 0
  // CHECK: tensor.empty()
  // CHECK: linalg.fill
  // CHECK: tensor.empty()
  // CHECK: linalg.max
  %0 = onednn_graph.relu %arg0 : (tensor<128x256xf32>) -> tensor<128x256xf32>
  return %0 : tensor<128x256xf32>
}
