Commit 8eeda34

Author: Longsheng Du
[Dialect] [OneDNNGraph] Add onednn_graph ops for llama2 mlp (#92)
Parent: c05bdfb
File tree: 3 files changed (+392, -28 lines)

include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.td

Lines changed: 135 additions & 13 deletions
@@ -24,22 +24,38 @@ include "gc/Dialect/OneDNNGraph/OneDNNGraphTypes.td"
 class OneDNNGraph_Op<string mnemonic, list<Trait> traits = []> :
     Op<OneDNNGraphDialect, mnemonic, traits>;

+class OneDNNGraph_ElemwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
+    OneDNNGraph_Op<mnemonic, traits #
+        [SameOperandsAndResultType]> {
+  let arguments = (ins OneDNNGraph_FloatTensor:$operand);
+  let results = (outs OneDNNGraph_FloatTensor:$result);
+
+  let assemblyFormat =
+      "operands attr-dict `:` functional-type(operands, results)";
+}
+
 class OneDNNGraph_ElemwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
-    OneDNNGraph_Op<mnemonic, traits # [SameOperandsAndResultElementType, InferTensorType,
-                   ResultsBroadcastableShape]> {
-  let arguments = (ins OneDNNGraph_FloatTensor:$input_0,
-                   OneDNNGraph_FloatTensor:$input_1);
+    OneDNNGraph_Op<mnemonic, traits #
+        [SameOperandsAndResultElementType, InferTensorTypeAdaptor, ResultsBroadcastableShape]> {
+  let arguments = (ins OneDNNGraph_FloatTensor:$input_a,
+                       OneDNNGraph_FloatTensor:$input_b,
+                       DefaultValuedOptionalAttr<BoolAttr, "true">:$auto_broadcast);
   let results = (outs OneDNNGraph_FloatTensor:$result);

   let assemblyFormat =
       "operands attr-dict `:` functional-type(operands, results)";
 }

-class OneDNNGraph_ElemwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
-    OneDNNGraph_Op<mnemonic, traits # [SameOperandsAndResultType]> {
-  let arguments = (ins OneDNNGraph_FloatTensor:$operand);
+class OneDNNGraph_ReduceOp<string mnemonic, list<Trait> traits = []> :
+    OneDNNGraph_Op<mnemonic, traits #
+        [SameOperandsAndResultElementType, InferTensorTypeAdaptor]> {
+  let arguments = (ins OneDNNGraph_FloatTensor:$operand,
+                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$axes,
+                       DefaultValuedOptionalAttr<BoolAttr, "false">:$keep_dims);
   let results = (outs OneDNNGraph_FloatTensor:$result);

+  let hasVerifier = 1;
+  let hasCanonicalizer = 1;
   let assemblyFormat =
       "operands attr-dict `:` functional-type(operands, results)";
 }
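For orientation, a minimal sketch of what these base classes produce in MLIR textual IR, assuming the dialect prints under the `onednn_graph` namespace and with illustrative operand names and shapes. With the default `auto_broadcast = true`, the result shape of a binary op follows multi-directional broadcast rules:

  // broadcast an 8x1 tensor against an 8x16 tensor; result shape is inferred
  %0 = onednn_graph.add %arg0, %arg1 : (tensor<8x1xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>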
@@ -48,36 +64,142 @@ class OneDNNGraph_ElemwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
 // OneDNNGraph op definitions
 //===----------------------------------------------------------------------===//

+// Matmul
+
 def OneDNNGraph_MatMulOp :
-    OneDNNGraph_Op<"matmul", [SameOperandsAndResultElementType, InferTensorTypeAdaptor]> {
-  let summary = "Generalized matrix multiplication";
+    OneDNNGraph_Op<"matmul",
+        [SameOperandsAndResultElementType, InferTensorTypeAdaptor]> {
+  let summary = [{
+    MatMul operation computes the product of two tensors with optional bias addition.
+  }];
   let description = [{
     `https://oneapi-src.github.io/oneDNN/dev_guide_op_matmul.html`
   }];

   let arguments = (ins OneDNNGraph_FloatTensor:$input_a,
                        OneDNNGraph_FloatTensor:$input_b,
-                   Optional<OneDNNGraph_LogicalTensor>:$bias,
-                   DefaultValuedAttr<BoolAttr, "false">:$transpose_a,
-                   DefaultValuedAttr<BoolAttr, "false">:$transpose_b);
+                       Optional<OneDNNGraph_FloatTensor>:$bias,
+                       DefaultValuedOptionalAttr<BoolAttr, "false">:$transpose_a,
+                       DefaultValuedOptionalAttr<BoolAttr, "false">:$transpose_b);
   let results = (outs OneDNNGraph_FloatTensor:$result);

   let assemblyFormat =
       "operands attr-dict `:` functional-type(operands, results)";
 }
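A hypothetical use of the matmul op with the new optional bias operand (operand names and shapes are illustrative; per the oneDNN matmul spec, bias is a 1-D tensor broadcast against the result):

  // 4x8 times 8x16 plus a length-16 bias broadcast over rows
  %0 = onednn_graph.matmul %a, %b, %bias : (tensor<4x8xf32>, tensor<8x16xf32>, tensor<16xf32>) -> tensor<4x16xf32>
  // transpose_b treats %c as transposed in its last two dims before multiplying
  %1 = onednn_graph.matmul %a, %c {transpose_b = true} : (tensor<4x8xf32>, tensor<16x8xf32>) -> tensor<4x16xf32>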
+// Common Unary
+
 def OneDNNGraph_ReLUOp : OneDNNGraph_ElemwiseUnaryOp<"relu"> {
   let summary = "element-wise relu";
   let description = [{
     `https://oneapi-src.github.io/oneDNN/dev_guide_op_relu.html`
   }];
 }

+def OneDNNGraph_SigmoidOp : OneDNNGraph_ElemwiseUnaryOp<"sigmoid"> {
+  let summary = "element-wise sigmoid";
+  let description = [{
+    `https://oneapi-src.github.io/oneDNN/dev_guide_op_sigmoid.html`
+  }];
+}
+
+// Special Unary
+
+def OneDNNGraph_TypeCastOp : OneDNNGraph_Op<"type_cast", [SameOperandsAndResultShape]> {
+  let summary = [{
+    TypeCast operation performs element-wise cast from input data type
+    to the data type given by output tensor.
+  }];
+  let description = [{
+    `https://oneapi-src.github.io/oneDNN/dev_guide_op_typecast.html`
+  }];
+
+  let arguments = (ins OneDNNGraph_FloatTensor:$operand);
+  let results = (outs OneDNNGraph_FloatTensor:$result);
+
+  let assemblyFormat =
+      "operands attr-dict `:` functional-type(operands, results)";
+}
+
+def OneDNNGraph_PowOp : OneDNNGraph_Op<"pow", [SameOperandsAndResultType]> {
+  let summary = [{
+    Pow operation performs an element-wise power operation on a given input
+    tensor with a single value attribute beta as its exponent.
+  }];
+  let description = [{
+    `https://oneapi-src.github.io/oneDNN/dev_guide_op_pow.html`
+  }];
+
+  let arguments = (ins OneDNNGraph_FloatTensor:$operand,
+                       F32Attr:$beta);
+  let results = (outs OneDNNGraph_FloatTensor:$result);
+
+  let assemblyFormat =
+      "operands attr-dict `:` functional-type(operands, results)";
+}
+
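A sketch of the two special unary ops under the same assumptions (hypothetical IR; bf16 is assumed to be among the dialect's float tensor element types):

  // cast element type while keeping the shape (SameOperandsAndResultShape)
  %0 = onednn_graph.type_cast %x : (tensor<4x8xbf16>) -> tensor<4x8xf32>
  // raise each element to the power given by the beta attribute
  %1 = onednn_graph.pow %y {beta = 2.0 : f32} : (tensor<4x8xf32>) -> tensor<4x8xf32>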
+// Common Binary
+
 def OneDNNGraph_AddOp : OneDNNGraph_ElemwiseBinaryOp<"add", [Commutative]> {
-  let summary = "element-wise addition with multi-directional broadcast";
+  let summary = [{
+    Add operation performs element-wise addition operation with two
+    given tensors applying multi-directional broadcast rules.
+  }];
   let description = [{
     `https://oneapi-src.github.io/oneDNN/dev_guide_op_add.html`
   }];
 }

+def OneDNNGraph_MulOp : OneDNNGraph_ElemwiseBinaryOp<"mul", [Commutative]> {
+  let summary = [{
+    Multiply operation performs element-wise multiply operation with two
+    given tensors applying multi-directional broadcast rules.
+  }];
+  let description = [{
+    `https://oneapi-src.github.io/oneDNN/dev_guide_op_multiply.html`
+  }];
+}
+
+def OneDNNGraph_SubOp : OneDNNGraph_ElemwiseBinaryOp<"sub"> {
+  let summary = [{
+    Subtract operation performs element-wise subtraction operation with
+    two given tensors applying multi-directional broadcast rules.
+  }];
+  let description = [{
+    `https://oneapi-src.github.io/oneDNN/dev_guide_op_subtract.html`
+  }];
+}
+
+def OneDNNGraph_DivOp : OneDNNGraph_ElemwiseBinaryOp<"div"> {
+  let summary = [{
+    Divide operation performs element-wise division operation with two
+    given tensors applying multi-directional broadcast rules.
+  }];
+  let description = [{
+    `https://oneapi-src.github.io/oneDNN/dev_guide_op_divide.html`
+  }];
+}
+
+// Common Reduce
+
+def OneDNNGraph_ReduceSumOp : OneDNNGraph_ReduceOp<"reduce_sum"> {
+  let summary = [{
+    ReduceSum operation performs the reduction with addition on a given
+    src data along dimensions specified by axes.
+  }];
+  let description = [{
+    `https://oneapi-src.github.io/oneDNN/dev_guide_op_reducesum.html`
+  }];
+}
+
+def OneDNNGraph_ReduceMeanOp : OneDNNGraph_ReduceOp<"reduce_mean"> {
+  let summary = [{
+    ReduceMean operation performs the reduction with finding the arithmetic
+    mean on a given src data along dimensions specified by axes.
+  }];
+  let description = [{
+    `https://oneapi-src.github.io/oneDNN/dev_guide_op_reducemean.html`
+  }];
+}
+
 #endif // ONEDNNGRAPH_OPS
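A hypothetical reduce op, illustrating the axes and keep_dims attributes declared on OneDNNGraph_ReduceOp (shapes illustrative). With empty axes, the default, the shape inference below simply returns the input shape unchanged:

  // reduce over axis 1, keeping a unit dim in its place
  %0 = onednn_graph.reduce_sum %x {axes = array<i64: 1>, keep_dims = true} : (tensor<4x8xf32>) -> tensor<4x1xf32>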

lib/gc/Dialect/OneDNNGraph/OneDNNGraphOps.cpp

Lines changed: 158 additions & 15 deletions
@@ -17,24 +17,166 @@
 namespace mlir {
 namespace onednn_graph {

-LogicalResult onednn_graph::AddOp::inferReturnTypeComponents(
-    MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
-    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  llvm::SmallVector<int64_t> outShape;
-  auto resultTy = dyn_cast<ShapedType>(operands.front().getType());
-  auto getShapeIdx = [&operands](size_t i) {
-    return operands.getTypes()[i].dyn_cast<ShapedType>().getShape();
+//===----------------------------------------------------------------------===//
+// Binary ops shape infer
+//===----------------------------------------------------------------------===//
+
+#define BINARY_OP_SHAPE_INFER(OP)                                             \
+  LogicalResult OP::inferReturnTypeComponents(                                \
+      MLIRContext *context, ::std::optional<Location> location,               \
+      OP::Adaptor adaptor,                                                    \
+      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {          \
+    auto inputTy0 = dyn_cast<ShapedType>(adaptor.getInputA().getType());      \
+    auto inputTy1 = dyn_cast<ShapedType>(adaptor.getInputB().getType());      \
+    if (!adaptor.getAutoBroadcast() && (inputTy0 != inputTy1)) {              \
+      return failure();                                                       \
+    }                                                                         \
+    llvm::SmallVector<int64_t> outShape;                                      \
+    auto ret = OpTrait::util::getBroadcastedShape(                            \
+        inputTy0.getShape(), inputTy1.getShape(), outShape);                  \
+    inferredReturnShapes.push_back(                                           \
+        ShapedTypeComponents(outShape, inputTy0.getElementType()));           \
+    return LogicalResult::success(ret);                                       \
+  }
+
+BINARY_OP_SHAPE_INFER(onednn_graph::AddOp)
+BINARY_OP_SHAPE_INFER(onednn_graph::MulOp)
+BINARY_OP_SHAPE_INFER(onednn_graph::SubOp)
+BINARY_OP_SHAPE_INFER(onednn_graph::DivOp)
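Worth noting from the macro above: when auto_broadcast is false, inference requires the two operand types to match exactly, so a sketch like the following would be rejected during shape inference (hypothetical IR, illustrative shapes):

  // fails inference: mismatched operand types with broadcasting disabled
  %0 = onednn_graph.add %arg0, %arg1 {auto_broadcast = false} : (tensor<8x1xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>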
+
+//===----------------------------------------------------------------------===//
+// Reduce ops shape infer
+//===----------------------------------------------------------------------===//
+
+SmallVector<int64_t> canonicalizeReduceAxes(ArrayRef<int64_t> axes,
+                                            int64_t rank) {
+  SmallVector<int64_t> ret(axes.size());
+  for (size_t i = 0; i < axes.size(); i++) {
+    ret[i] = axes[i] < 0 ? axes[i] + rank : axes[i];
+  }
+  llvm::sort(ret);
+  ret.erase(std::unique(ret.begin(), ret.end()), ret.end());
+  return ret;
+}
+
+SmallVector<int64_t> getReducedShape(ShapeAdaptor operandShape,
+                                     ArrayRef<int64_t> axes, bool keep_dims) {
+  SmallVector<int64_t> outputShape;
+  // get reduce axis one by one
+  size_t index = 0;
+  auto getNextReduceAxis = [&]() {
+    return (index >= axes.size()) ? -1 : axes[index++];
   };
+  // get reduced shape
+  auto rank = operandShape.getRank();
+  auto axis = getNextReduceAxis();
+  for (int64_t idx = 0; idx < rank; idx++) {
+    if (idx == axis) {
+      axis = getNextReduceAxis();
+      if (keep_dims) {
+        outputShape.push_back(1);
+      }
+    } else {
+      outputShape.push_back(operandShape.getDimSize(idx));
+    }
+  }
+  return outputShape;
+}

-  auto ret = OpTrait::util::getBroadcastedShape(getShapeIdx(0), getShapeIdx(1),
-                                                outShape);
-  inferredReturnShapes.push_back(
-      ShapedTypeComponents(outShape, resultTy.getElementType()));
-  return LogicalResult::success(ret);
+static LogicalResult InferReduceReturnTypes(
+    ShapeAdaptor operandShape, Type elemType, ArrayRef<int64_t> axes,
+    bool keep_dims,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  // no reduce axes
+  if (axes.empty()) {
+    inferredReturnShapes.push_back(ShapedTypeComponents(operandShape));
+    return success();
+  }
+  inferredReturnShapes.push_back(ShapedTypeComponents(
+      getReducedShape(operandShape, axes, keep_dims), elemType));
+  return success();
 }

+template <typename ReduceOp>
+struct CanonicalizeReduceOp : public OpRewritePattern<ReduceOp> {
+  using OpRewritePattern<ReduceOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(ReduceOp op,
+                                PatternRewriter &rewriter) const override {
+    auto rank = dyn_cast<ShapedType>(op.getOperand().getType()).getRank();
+    // consider canonicalized if all axes are non-negative in ascending order
+    // Note: disable tidy here due to dangling reference in OperationState
+    // NOLINTBEGIN
+    bool canonicalized = true;
+    int64_t last = -1;
+    for (const auto axis : op.getAxes()) {
+      if (axis <= last) {
+        canonicalized = false;
+        break;
+      }
+      last = axis;
+    }
+    if (canonicalized) {
+      return failure();
+    }
+    // canonicalize the reduce axes
+    auto new_axes = canonicalizeReduceAxes(op.getAxes(), rank);
+    auto new_op = rewriter.create<ReduceOp>(
+        op.getLoc(), op.getType(), op.getOperand(), new_axes, op.getKeepDims());
+    rewriter.replaceOp(op, new_op);
+    // NOLINTEND
+    return success();
+  }
+};
+
+#define REDUCE_OP_SHAPE_CANONICALIZE(OP)                                      \
+  void OP::getCanonicalizationPatterns(RewritePatternSet &results,            \
+                                       MLIRContext *context) {                \
+    using CanonicalizeOp = CanonicalizeReduceOp<OP>;                          \
+    results.add<CanonicalizeOp>(context);                                     \
+  }
+
+#define REDUCE_OP_SHAPE_INFER(OP)                                             \
+  LogicalResult OP::inferReturnTypeComponents(                                \
+      MLIRContext *context, ::std::optional<Location> location,               \
+      OP::Adaptor adaptor,                                                    \
+      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {          \
+    llvm::SmallVector<int64_t> outShape;                                      \
+    auto operandTy = dyn_cast<ShapedType>(adaptor.getOperand().getType());    \
+    auto rank = operandTy.getRank();                                          \
+    ShapeAdaptor inputShape(operandTy);                                       \
+    return InferReduceReturnTypes(                                            \
+        inputShape, operandTy.getElementType(),                               \
+        canonicalizeReduceAxes(adaptor.getAxes(), rank),                      \
+        adaptor.getKeepDims(), inferredReturnShapes);                         \
+  }
+
+#define REDUCE_OP_VERIFY(OP)                                                  \
+  LogicalResult OP::verify() {                                                \
+    auto operandTy = dyn_cast<ShapedType>(getOperand().getType());            \
+    if (!operandTy.hasRank()) {                                               \
+      return emitOpError("Invalid operand shape!\n");                         \
+    }                                                                         \
+    int64_t rank = operandTy.getRank();                                       \
+    for (const auto axis : canonicalizeReduceAxes(getAxes(), rank)) {         \
+      if (axis >= rank || axis < 0) {                                         \
+        return emitOpError("Reduce axis not valid!\n");                       \
+      }                                                                       \
+    }                                                                         \
+    return success();                                                         \
+  }
+
+#define REDUCE_OP_DEFINE(OP)                                                  \
+  REDUCE_OP_SHAPE_CANONICALIZE(OP)                                            \
+  REDUCE_OP_SHAPE_INFER(OP)                                                   \
+  REDUCE_OP_VERIFY(OP)
+
+REDUCE_OP_DEFINE(onednn_graph::ReduceSumOp)
+REDUCE_OP_DEFINE(onednn_graph::ReduceMeanOp)
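The canonicalizer pattern above fires only when the axes list is not already non-negative and strictly ascending; a sketch of the rewrite it performs on a rank-3 operand (hypothetical IR, illustrative shapes):

  // before: negative and unsorted axes
  %0 = onednn_graph.reduce_sum %x {axes = array<i64: -1, 0>} : (tensor<2x4x8xf32>) -> tensor<4xf32>
  // after canonicalizeReduceAxes: -1 wraps to 2, axes sorted to {0, 2}
  %0 = onednn_graph.reduce_sum %x {axes = array<i64: 0, 2>} : (tensor<2x4x8xf32>) -> tensor<4xf32>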
+
+//===----------------------------------------------------------------------===//
+// Matmul ops shape infer
+//===----------------------------------------------------------------------===//
+
 LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     MatMulOp::Adaptor adaptor,
@@ -44,6 +186,7 @@ LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents(
       const ShapeAdaptor &rhsShape, int64_t range,
       int64_t diff, SmallVector<int64_t> &outDims) {
     for (int64_t i = 0; i < range; i++) {
+      // TODO(longsheng): add OpTrait::util::getBroadcastedShape for batch
       int64_t idx = i - diff;
       if ((idx >= 0) && (lhsShape.getDimSize(i) != rhsShape.getDimSize(idx))) {
         return failure();
@@ -134,7 +277,7 @@ LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents(
   SmallVector<int64_t> resultShape;
   if (!biasRankMatch ||
       !OpTrait::util::getBroadcastedShape(
-          retShape.getDims(), biasType.dyn_cast<ShapedType>().getShape(),
+          retShape.getDims(), dyn_cast<ShapedType>(biasType).getShape(),
           resultShape)) {
     return failure();
   }
