Skip to content

Commit d7c3c0b

Browse files
author
Longsheng Du
authored
[Dialect] [OneDNNGraph] Add ops lowering for llama2 mlp (#107)
1 parent 43a83fc commit d7c3c0b

File tree

8 files changed

+626
-108
lines changed

8 files changed

+626
-108
lines changed

include/gc/Transforms/Passes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ class OpenMPDialect;
2828
namespace linalg {
2929
class LinalgDialect;
3030
}
31+
namespace linalgx {
32+
class LinalgxDialect;
33+
}
3134

3235
namespace MemRef {
3336
class MemRefDialect;

include/gc/Transforms/Passes.td

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,12 @@ def ConvertOneDNNGraphToLinalg : Pass<"convert-onednn-graph-to-linalg"> {
2323
Lowers the `onednn_graph` ops to `linalg` ops.
2424
}];
2525
let dependentDialects = [
26-
"func::FuncDialect",
27-
"math::MathDialect",
28-
"arith::ArithDialect",
29-
"tensor::TensorDialect",
30-
"linalg::LinalgDialect"
26+
"func::FuncDialect",
27+
"math::MathDialect",
28+
"arith::ArithDialect",
29+
"tensor::TensorDialect",
30+
"linalg::LinalgDialect",
31+
"linalgx::LinalgxDialect"
3132
];
3233
}
3334

@@ -37,6 +38,7 @@ def GCCPUPipeline: Pass<"gc-cpu-pipeline"> {
3738
"tensor::TensorDialect",
3839
"memref::MemRefDialect",
3940
"linalg::LinalgDialect",
41+
"linalgx::LinalgxDialect",
4042
"LLVM::LLVMDialect",
4143
"scf::SCFDialect",
4244
"bufferization::BufferizationDialect",

lib/gc/Dialect/OneDNNGraph/OneDNNGraphOps.cpp

Lines changed: 85 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ BINARY_OP_SHAPE_INFER(onednn_graph::DivOp)
4848
// Reduce ops shape infer
4949
//===----------------------------------------------------------------------===//
5050

51+
// canonicalize reduced axes
52+
// make all axes inside the reduced-axes array non-negative, in ascending order
5153
SmallVector<int64_t> canonicalizeReduceAxes(ArrayRef<int64_t> axes,
5254
int64_t rank) {
5355
SmallVector<int64_t> ret(axes.size());
@@ -59,21 +61,46 @@ SmallVector<int64_t> canonicalizeReduceAxes(ArrayRef<int64_t> axes,
5961
return ret;
6062
}
6163

62-
SmallVector<int64_t> getReducedShape(ShapeAdaptor operandShape,
63-
ArrayRef<int64_t> axes, bool keep_dims) {
64-
SmallVector<int64_t> outputShape;
64+
// canonicalize kept axes
65+
// make all axes inside the kept-axes array non-negative, in ascending order
66+
SmallVector<int64_t> canonicalizeKeepAxes(ArrayRef<int64_t> axes, int64_t rank,
67+
bool canonicalized = false) {
68+
// get canonicalized reduce axes
69+
auto newCanonicalized = canonicalized ? SmallVector<int64_t>{}
70+
: canonicalizeReduceAxes(axes, rank);
71+
auto reduceAxes = canonicalized ? axes : ArrayRef<int64_t>(newCanonicalized);
72+
// get kept axes
73+
SmallVector<int64_t> keepAxes;
74+
for (int64_t dim = 0, idx = 0; dim < rank; dim++) {
75+
if (idx < (int64_t)reduceAxes.size() && reduceAxes[idx] == dim) {
76+
idx++;
77+
continue;
78+
}
79+
keepAxes.push_back(dim);
80+
}
81+
return keepAxes;
82+
}
83+
84+
SmallVector<int64_t> inferReducedShape(ShapedType operandShape,
85+
ArrayRef<int64_t> axes, bool keepDims,
86+
bool canonicalized = false) {
87+
// get canonicalized reduce axes
88+
auto rank = operandShape.getRank();
89+
auto newCanonicalized = canonicalized ? SmallVector<int64_t>{}
90+
: canonicalizeReduceAxes(axes, rank);
91+
auto reduceAxes = canonicalized ? axes : ArrayRef<int64_t>(newCanonicalized);
6592
// get reduce axis one by one
6693
size_t index = 0;
6794
auto getNextReduceAxis = [&]() {
68-
return (index >= axes.size()) ? -1 : axes[index++];
95+
return (index >= reduceAxes.size()) ? -1 : reduceAxes[index++];
6996
};
7097
// get reduced shape
71-
auto rank = operandShape.getRank();
98+
SmallVector<int64_t> outputShape;
7299
auto axis = getNextReduceAxis();
73100
for (int64_t idx = 0; idx < rank; idx++) {
74101
if (idx == axis) {
75102
axis = getNextReduceAxis();
76-
if (keep_dims) {
103+
if (keepDims) {
77104
outputShape.push_back(1);
78105
}
79106
} else {
@@ -84,16 +111,16 @@ SmallVector<int64_t> getReducedShape(ShapeAdaptor operandShape,
84111
}
85112

86113
static LogicalResult InferReduceReturnTypes(
87-
ShapeAdaptor operandShape, Type elemType, ArrayRef<int64_t> axes,
88-
bool keep_dims,
114+
ShapedType operandTy, ArrayRef<int64_t> axes, bool keepDims,
89115
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
90116
// no reduce axes
91117
if (axes.empty()) {
92-
inferredReturnShapes.push_back(ShapedTypeComponents(operandShape));
118+
inferredReturnShapes.push_back(ShapedTypeComponents(operandTy));
93119
return success();
94120
}
95-
inferredReturnShapes.push_back(ShapedTypeComponents(
96-
getReducedShape(operandShape, axes, keep_dims), elemType));
121+
inferredReturnShapes.push_back(
122+
ShapedTypeComponents(inferReducedShape(operandTy, axes, keepDims),
123+
operandTy.getElementType()));
97124
return success();
98125
}
99126

@@ -119,10 +146,10 @@ struct CanonicalizeReduceOp : public OpRewritePattern<ReduceOp> {
119146
return failure();
120147
}
121148
// canonicalize the reduce axes
122-
auto new_axes = canonicalizeReduceAxes(op.getAxes(), rank);
123-
auto new_op = rewriter.create<ReduceOp>(
124-
op.getLoc(), op.getType(), op.getOperand(), new_axes, op.getKeepDims());
125-
rewriter.replaceOp(op, new_op);
149+
auto newAxes = canonicalizeReduceAxes(op.getAxes(), rank);
150+
auto newOp = rewriter.create<ReduceOp>(
151+
op.getLoc(), op.getType(), op.getOperand(), newAxes, op.getKeepDims());
152+
rewriter.replaceOp(op, newOp);
126153
// NOLINTEND
127154
return success();
128155
}
@@ -142,12 +169,9 @@ struct CanonicalizeReduceOp : public OpRewritePattern<ReduceOp> {
142169
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
143170
llvm::SmallVector<int64_t> outShape; \
144171
auto operandTy = dyn_cast<ShapedType>(adaptor.getOperand().getType()); \
145-
auto rank = operandTy.getRank(); \
146-
ShapeAdaptor inputShape(operandTy); \
147-
return InferReduceReturnTypes( \
148-
inputShape, operandTy.getElementType(), \
149-
canonicalizeReduceAxes(adaptor.getAxes(), rank), \
150-
adaptor.getKeepDims(), inferredReturnShapes); \
172+
return InferReduceReturnTypes(operandTy, adaptor.getAxes(), \
173+
adaptor.getKeepDims(), \
174+
inferredReturnShapes); \
151175
}
152176

153177
#define REDUCE_OP_VERIFY(OP) \
@@ -181,22 +205,37 @@ LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents(
181205
MLIRContext *context, ::std::optional<Location> location,
182206
MatMulOp::Adaptor adaptor,
183207
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
184-
// get batch dims from shape
185-
auto extractBatch = [](const ShapeAdaptor &lhsShape,
186-
const ShapeAdaptor &rhsShape, int64_t range,
187-
int64_t diff, SmallVector<int64_t> &outDims) {
188-
for (int64_t i = 0; i < range; i++) {
189-
// TODO(longsheng): add OpTrait::util::getBroadcastedShape for batch
190-
int64_t idx = i - diff;
191-
if ((idx >= 0) && (lhsShape.getDimSize(i) != rhsShape.getDimSize(idx))) {
192-
return failure();
193-
}
194-
outDims.push_back(lhsShape.getDimSize(i));
208+
// get batch dims from a single multi-batch matrix shape
209+
auto extractBatch = [](ShapedType shape, SmallVector<int64_t> &outDims) {
210+
// assuming last 2 input dims are row and col
211+
assert(shape.getRank() >= 2);
212+
for (int64_t i = 0; i < shape.getRank() - 2; i++) {
213+
outDims.push_back(shape.getDimSize(i));
195214
}
196215
return success();
197216
};
217+
// get broadcasted batch dims from 2 multi-batch matrix shapes
218+
auto extractBroadcastBatch = [](ShapedType lhsType, ShapedType rhsType,
219+
SmallVector<int64_t> &outDims) {
220+
SmallVector<int64_t> lhsShape(lhsType.getShape());
221+
SmallVector<int64_t> rhsShape(rhsType.getShape());
222+
assert(lhsShape.size() >= 2 && rhsShape.size() >= 2);
223+
// assuming last 2 input dims are row and col
224+
// 0xFF is just a random number > 1, replacing the row and col dims
225+
// so that getBroadcastedShape can match, will be removed after
226+
lhsShape[lhsShape.size() - 1] = 0xFF;
227+
lhsShape[lhsShape.size() - 2] = 0xFF;
228+
rhsShape[rhsShape.size() - 1] = 0xFF;
229+
rhsShape[rhsShape.size() - 2] = 0xFF;
230+
bool ret = OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, outDims);
231+
// remove 0xFF
232+
assert(outDims.size() >= 2);
233+
outDims.pop_back();
234+
outDims.pop_back();
235+
return LogicalResult::success(ret);
236+
};
198237
// get row col of 2d matrix according to transpose info
199-
auto getMatRowCol = [](const ShapeAdaptor &shape, bool transpose) {
238+
auto getMatRowCol = [](ShapedType shape, bool transpose) {
200239
using pairRowCol = std::pair<int64_t, int64_t>;
201240
auto rank = shape.getRank();
202241
assert(rank > 1);
@@ -205,8 +244,8 @@ LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents(
205244
: pairRowCol(shape.getDimSize(rank - 2),
206245
shape.getDimSize(rank - 1));
207246
};
208-
ShapeAdaptor lhsShape(adaptor.getInputA().getType());
209-
ShapeAdaptor rhsShape(adaptor.getInputB().getType());
247+
auto lhsShape = cast<ShapedType>(adaptor.getInputA().getType());
248+
auto rhsShape = cast<ShapedType>(adaptor.getInputB().getType());
210249
bool transposeA = adaptor.getTransposeA();
211250
bool transposeB = adaptor.getTransposeB();
212251
int64_t lRank = lhsShape.getRank();
@@ -223,36 +262,24 @@ LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents(
223262
} else if (lRank == 1 && rRank > 1) {
224263
// 1D x ND
225264
auto rMatRowCol = getMatRowCol(rhsShape, transposeB);
226-
status = extractBatch(rhsShape, rhsShape, rRank - 2, 0, outShape);
265+
status = extractBatch(rhsShape, outShape);
227266
if (lhsShape.getDimSize(0) != rMatRowCol.first) {
228267
return failure();
229268
}
230269
outShape.push_back(rhsShape.getDimSize(rMatRowCol.second));
231270
} else if (lRank > 1 && rRank == 1) {
232271
// ND x 1D
233272
auto lMatRowCol = getMatRowCol(lhsShape, transposeA);
234-
status = extractBatch(lhsShape, lhsShape, lRank - 2, 0, outShape);
273+
status = extractBatch(lhsShape, outShape);
235274
if (lMatRowCol.second != rhsShape.getDimSize(0)) {
236275
return failure();
237276
}
238277
outShape.push_back(lhsShape.getDimSize(lMatRowCol.first));
239278
} else if (lRank > 1 && rRank > 1) {
240-
if (lRank == rRank) {
241-
// ND x ND
242-
auto range = lRank - 2;
243-
status = extractBatch(lhsShape, rhsShape, range, 0, outShape);
244-
} else if (lRank > rRank) {
245-
// MD x ND (M > N)
246-
auto range = lRank - 2;
247-
auto diff = lRank - rRank;
248-
status = extractBatch(lhsShape, rhsShape, range, diff, outShape);
249-
} else {
250-
// ND x MD (M > N)
251-
auto range = rRank - 2;
252-
auto diff = rRank - lRank;
253-
status = extractBatch(rhsShape, lhsShape, range, diff, outShape);
254-
}
255-
//
279+
// ND x ND
280+
// MD x ND (M > N)
281+
// ND x MD (M > N)
282+
status = extractBroadcastBatch(lhsShape, rhsShape, outShape);
256283
auto lMatRowCol = getMatRowCol(lhsShape, transposeA);
257284
auto rMatRowCol = getMatRowCol(rhsShape, transposeB);
258285
if (failed(status) || (lMatRowCol.second != rMatRowCol.first)) {
@@ -269,16 +296,13 @@ LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents(
269296
inferredReturnShapes.push_back(retShape);
270297
// check for bias broadcasting
271298
if (adaptor.getBias()) {
272-
auto biasType = adaptor.getBias().getType();
273-
ShapeAdaptor biasShape(biasType);
274-
275-
bool biasRankMatch = biasShape.getRank() == 1 ||
276-
biasShape.getRank() == (int64_t)outShape.size();
299+
auto biasType = dyn_cast<ShapedType>(adaptor.getBias().getType());
300+
bool biasRankMatch = biasType.getRank() == 1 ||
301+
biasType.getRank() == (int64_t)outShape.size();
277302
SmallVector<int64_t> resultShape;
278303
if (!biasRankMatch ||
279-
!OpTrait::util::getBroadcastedShape(
280-
retShape.getDims(), dyn_cast<ShapedType>(biasType).getShape(),
281-
resultShape)) {
304+
!OpTrait::util::getBroadcastedShape(retShape.getDims(),
305+
biasType.getShape(), resultShape)) {
282306
return failure();
283307
}
284308
}

0 commit comments

Comments
 (0)