Commit 0781b4f

Author: Longsheng Du
Commit message: rebase
1 parent fce4118 commit 0781b4f

3 files changed: +379, -27 lines


include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.td

Lines changed: 134 additions & 12 deletions
@@ -24,22 +24,38 @@ include "OneDNNGraphTypes.td"
 class OneDNNGraph_Op<string mnemonic, list<Trait> traits = []> :
     Op<OneDNNGraphDialect, mnemonic, traits>;
 
+class OneDNNGraph_ElemwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
+    OneDNNGraph_Op<mnemonic, traits #
+        [SameOperandsAndResultType]> {
+  let arguments = (ins OneDNNGraph_FloatTensor:$operand);
+  let results = (outs OneDNNGraph_FloatTensor:$result);
+
+  let assemblyFormat =
+      "operands attr-dict `:` functional-type(operands, results)";
+}
+
 class OneDNNGraph_ElemwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
-    OneDNNGraph_Op<mnemonic, traits # [SameOperandsAndResultElementType, InferTensorType,
-                   ResultsBroadcastableShape]> {
-  let arguments = (ins OneDNNGraph_FloatTensor:$input_0,
-                       OneDNNGraph_FloatTensor:$input_1);
+    OneDNNGraph_Op<mnemonic, traits #
+        [SameOperandsAndResultElementType, InferTensorTypeAdaptor, ResultsBroadcastableShape]> {
+  let arguments = (ins OneDNNGraph_FloatTensor:$input_a,
+                       OneDNNGraph_FloatTensor:$input_b,
+                       DefaultValuedOptionalAttr<BoolAttr, "true">:$auto_broadcast);
   let results = (outs OneDNNGraph_FloatTensor:$result);
 
   let assemblyFormat =
       "operands attr-dict `:` functional-type(operands, results)";
 }
 
-class OneDNNGraph_ElemwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
-    OneDNNGraph_Op<mnemonic, traits # [SameOperandsAndResultType]> {
-  let arguments = (ins OneDNNGraph_FloatTensor:$operand);
+class OneDNNGraph_ReduceOp<string mnemonic, list<Trait> traits = []> :
+    OneDNNGraph_Op<mnemonic, traits #
+        [SameOperandsAndResultElementType, InferTensorTypeAdaptor]> {
+  let arguments = (ins OneDNNGraph_FloatTensor:$operand,
+                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$axes,
+                       DefaultValuedOptionalAttr<BoolAttr, "false">:$keep_dims);
   let results = (outs OneDNNGraph_FloatTensor:$result);
 
+  let hasVerifier = 1;
+  let hasCanonicalizer = 1;
   let assemblyFormat =
       "operands attr-dict `:` functional-type(operands, results)";
 }
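With this shared assembly format, every elementwise op round-trips in the same functional-type style; a relu, for instance, would print roughly as follows (an illustrative sketch, with SSA names and tensor shapes assumed rather than taken from the patch):

  %0 = onednn_graph.relu %arg0 : (tensor<4x8xf32>) -> tensor<4x8xf32>

SameOperandsAndResultType pins the result type to the operand type, so only the element-wise semantics differ between ops derived from the unary base class.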
@@ -48,36 +64,142 @@ class OneDNNGraph_ElemwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
 // OneDNNGraph op definitions
 //===----------------------------------------------------------------------===//
 
+// Matmul
+
 def OneDNNGraph_MatMulOp :
-    OneDNNGraph_Op<"matmul", [SameOperandsAndResultElementType, InferTensorTypeAdaptor]> {
-  let summary = "Generalized matrix multiplication";
+    OneDNNGraph_Op<"matmul",
+        [SameOperandsAndResultElementType, InferTensorTypeAdaptor]> {
+  let summary = [{
+    MatMul operation computes the product of two tensors with optional bias addition.
+  }];
   let description = [{
     `https://oneapi-src.github.io/oneDNN/dev_guide_op_matmul.html`
   }];
 
   let arguments = (ins OneDNNGraph_FloatTensor:$input_a,
                        OneDNNGraph_FloatTensor:$input_b,
                        Optional<OneDNNGraph_LogicalTensor>:$bias,
-                       DefaultValuedAttr<BoolAttr, "false">:$transpose_a,
-                       DefaultValuedAttr<BoolAttr, "false">:$transpose_b);
+                       DefaultValuedOptionalAttr<BoolAttr, "false">:$transpose_a,
+                       DefaultValuedOptionalAttr<BoolAttr, "false">:$transpose_b);
   let results = (outs OneDNNGraph_FloatTensor:$result);
 
   let assemblyFormat =
       "operands attr-dict `:` functional-type(operands, results)";
 }
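As a sketch of the declared format (shapes and names assumed, not from the patch), a matmul with bias and a transposed second input could read:

  %0 = onednn_graph.matmul %a, %b, %bias {transpose_b = true}
       : (tensor<16x32xf32>, tensor<64x32xf32>, tensor<64xf32>) -> tensor<16x64xf32>

The bias operand is optional; when omitted, only the two inputs appear before the colon.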
 
+// Common Unary
+
 def OneDNNGraph_ReLUOp : OneDNNGraph_ElemwiseUnaryOp<"relu"> {
   let summary = "element-wise relu";
   let description = [{
     `https://oneapi-src.github.io/oneDNN/dev_guide_op_relu.html`
   }];
 }
 
+def OneDNNGraph_SigmoidOp : OneDNNGraph_ElemwiseUnaryOp<"sigmoid"> {
+  let summary = "element-wise sigmoid";
+  let description = [{
+    `https://oneapi-src.github.io/oneDNN/dev_guide_op_sigmoid.html`
+  }];
+}
+
+// Special Unary
+
+def OneDNNGraph_TypeCastOp : OneDNNGraph_Op<"type_cast", [SameOperandsAndResultShape]> {
+  let summary = [{
+    TypeCast operation performs element-wise cast from input data type
+    to the data type given by output tensor.
+  }];
+  let description = [{
+    `https://oneapi-src.github.io/oneDNN/dev_guide_op_typecast.html`
+  }];
+
+  let arguments = (ins OneDNNGraph_FloatTensor:$operand);
+  let results = (outs OneDNNGraph_FloatTensor:$result);
+
+  let assemblyFormat =
+      "operands attr-dict `:` functional-type(operands, results)";
+}
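Because the op only carries SameOperandsAndResultShape, the element type may change while the shape must not; assuming bf16 is among the supported float element types, a cast might look like:

  %0 = onednn_graph.type_cast %arg0 : (tensor<4x8xf32>) -> tensor<4x8xbf16>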
+
+def OneDNNGraph_PowOp : OneDNNGraph_Op<"pow", [SameOperandsAndResultType]> {
+  let summary = [{
+    Pow operation performs an element-wise power operation on a given input
+    tensor with a single value attribute beta as its exponent.
+  }];
+  let description = [{
+    `https://oneapi-src.github.io/oneDNN/dev_guide_op_pow.html`
+  }];
+
+  let arguments = (ins OneDNNGraph_FloatTensor:$operand,
+                       F32Attr:$beta);
+  let results = (outs OneDNNGraph_FloatTensor:$result);
+
+  let assemblyFormat =
+      "operands attr-dict `:` functional-type(operands, results)";
+}
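Since beta is a required F32Attr, it always appears in the attribute dictionary; an element-wise square would be written roughly as (illustrative only):

  %0 = onednn_graph.pow %arg0 {beta = 2.0 : f32} : (tensor<4x8xf32>) -> tensor<4x8xf32>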
+
+// Common Binary
+
 def OneDNNGraph_AddOp : OneDNNGraph_ElemwiseBinaryOp<"add", [Commutative]> {
-  let summary = "element-wise addition with multi-directional broadcast";
+  let summary = [{
+    Add operation performs element-wise addition operation with two
+    given tensors applying multi-directional broadcast rules.
+  }];
   let description = [{
     `https://oneapi-src.github.io/oneDNN/dev_guide_op_add.html`
   }];
 }
 
+def OneDNNGraph_MulOp : OneDNNGraph_ElemwiseBinaryOp<"mul", [Commutative]> {
+  let summary = [{
+    Multiply operation performs element-wise multiply operation with two
+    given tensors applying multi-directional broadcast rules.
+  }];
+  let description = [{
+    `https://oneapi-src.github.io/oneDNN/dev_guide_op_multiply.html`
+  }];
+}
+
+def OneDNNGraph_SubOp : OneDNNGraph_ElemwiseBinaryOp<"sub"> {
+  let summary = [{
+    Subtract operation performs element-wise subtraction operation with
+    two given tensors applying multi-directional broadcast rules.
+  }];
+  let description = [{
+    `https://oneapi-src.github.io/oneDNN/dev_guide_op_subtract.html`
+  }];
+}
+
+def OneDNNGraph_DivOp : OneDNNGraph_ElemwiseBinaryOp<"div"> {
+  let summary = [{
+    Divide operation performs element-wise division operation with two
+    given tensors applying multi-directional broadcast rules.
+  }];
+  let description = [{
+    `https://oneapi-src.github.io/oneDNN/dev_guide_op_divide.html`
+  }];
+}
+
+// Common Reduce
+
+def OneDNNGraph_ReduceSumOp : OneDNNGraph_ReduceOp<"reduce_sum"> {
+  let summary = [{
+    ReduceSum operation performs the reduction with addition on a given
+    src data along dimensions specified by axes.
+  }];
+  let description = [{
+    `https://oneapi-src.github.io/oneDNN/dev_guide_op_reducesum.html`
+  }];
+}
+
+def OneDNNGraph_ReduceMeanOp : OneDNNGraph_ReduceOp<"reduce_mean"> {
+  let summary = [{
+    ReduceMean operation performs the reduction with finding the arithmetic
+    mean on a given src data along dimensions specified by axes.
+  }];
+  let description = [{
+    `https://oneapi-src.github.io/oneDNN/dev_guide_op_reducemean.html`
+  }];
+}
+
 #endif // ONEDNNGRAPH_OPS

lib/gc/Dialect/OneDNNGraph/OneDNNGraphOps.cpp

Lines changed: 146 additions & 15 deletions
@@ -17,24 +17,155 @@
 namespace mlir {
 namespace onednn_graph {
 
-LogicalResult onednn_graph::AddOp::inferReturnTypeComponents(
-    MLIRContext *context, ::std::optional<Location> location,
-    ValueShapeRange operands, DictionaryAttr attributes,
-    OpaqueProperties properties, RegionRange regions,
+//===----------------------------------------------------------------------===//
+// Binary ops shape infer
+//===----------------------------------------------------------------------===//
+
+#define BINARY_OP_SHAPE_INFER(OP)                                              \
+  LogicalResult OP::inferReturnTypeComponents(                                 \
+      MLIRContext *context, ::std::optional<Location> location,                \
+      OP::Adaptor adaptor,                                                     \
+      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
+    auto inputTy0 = dyn_cast<ShapedType>(adaptor.getInputA().getType());       \
+    auto inputTy1 = dyn_cast<ShapedType>(adaptor.getInputB().getType());       \
+    if (!adaptor.getAutoBroadcast() && (inputTy0 != inputTy1)) {               \
+      return failure();                                                        \
+    }                                                                          \
+    llvm::SmallVector<int64_t> outShape;                                       \
+    auto ret = OpTrait::util::getBroadcastedShape(                             \
+        inputTy0.getShape(), inputTy1.getShape(), outShape);                   \
+    inferredReturnShapes.push_back(                                            \
+        ShapedTypeComponents(outShape, inputTy0.getElementType()));            \
+    return LogicalResult::success(ret);                                        \
+  }
+
+BINARY_OP_SHAPE_INFER(onednn_graph::AddOp)
+BINARY_OP_SHAPE_INFER(onednn_graph::MulOp)
+BINARY_OP_SHAPE_INFER(onednn_graph::SubOp)
+BINARY_OP_SHAPE_INFER(onednn_graph::DivOp)
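The macro stamps out the same inference logic for all four binary ops: broadcast the two operand shapes and take the element type from the first operand. As an illustrative sketch (shapes assumed, not from the patch), multi-directional broadcasting gives:

  %0 = onednn_graph.add %a, %b : (tensor<8x1x4xf32>, tensor<3x1xf32>) -> tensor<8x3x4xf32>

With auto_broadcast = false, inference fails unless the two operand types match exactly.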
+
+//===----------------------------------------------------------------------===//
+// Reduce ops shape infer
+//===----------------------------------------------------------------------===//
+
+SmallVector<int64_t> canonicalizeReduceAxes(ArrayRef<int64_t> axes,
+                                            int64_t rank) {
+  SmallVector<int64_t> ret(axes.size());
+  for (size_t i = 0; i < axes.size(); i++) {
+    ret[i] = axes[i] < 0 ? axes[i] + rank : axes[i];
+  }
+  llvm::sort(ret);
+  ret.erase(std::unique(ret.begin(), ret.end()), ret.end());
+  return ret;
+}
+
+static LogicalResult InferReduceReturnTypes(
+    ShapeAdaptor operandShape, Type elemType, ArrayRef<int64_t> axes,
+    bool keep_dims,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  llvm::SmallVector<int64_t> outShape;
-  auto resultTy = dyn_cast<ShapedType>(operands.front().getType());
-  auto getShapeIdx = [&operands](size_t i) {
-    return operands.getTypes()[i].dyn_cast<ShapedType>().getShape();
+  // no reduce axes
+  if (axes.empty()) {
+    inferredReturnShapes.push_back(ShapedTypeComponents(operandShape));
+    return success();
+  }
+  // get reduce axis one by one
+  size_t index = 0;
+  auto getNextReduceAxis = [&]() {
+    return (index >= axes.size()) ? -1 : axes[index++];
   };
-
-  auto ret = OpTrait::util::getBroadcastedShape(getShapeIdx(0), getShapeIdx(1),
-                                                outShape);
-  inferredReturnShapes.push_back(
-      ShapedTypeComponents(outShape, resultTy.getElementType()));
-  return LogicalResult::success(ret);
+  // get reduced shape
+  auto rank = operandShape.getRank();
+  auto axis = getNextReduceAxis();
+  SmallVector<int64_t> outputShape;
+  for (int64_t idx = 0; idx < rank; idx++) {
+    if (idx == axis) {
+      axis = getNextReduceAxis();
+      if (keep_dims) {
+        outputShape.push_back(1);
+      }
+    } else {
+      outputShape.push_back(operandShape.getDimSize(idx));
+    }
+  }
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, elemType));
+  return success();
 }
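Illustrating the inference rules above (shapes assumed): an empty axes list leaves the shape untouched, a reduced axis is dropped by default, and keep_dims retains it with extent 1:

  %0 = onednn_graph.reduce_sum %a {axes = array<i64: 1>}
       : (tensor<2x3x4xf32>) -> tensor<2x4xf32>
  %1 = onednn_graph.reduce_sum %a {axes = array<i64: 1>, keep_dims = true}
       : (tensor<2x3x4xf32>) -> tensor<2x1x4xf32>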
 
+template <typename ReduceOp>
+struct CanonicalizeReduceOp : public OpRewritePattern<ReduceOp> {
+  using OpRewritePattern<ReduceOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(ReduceOp op,
+                                PatternRewriter &rewriter) const override {
+    auto rank = dyn_cast<ShapedType>(op.getOperand().getType()).getRank();
+    // consider canonicalized if all axes are non-negative in ascending order
+    int64_t last = -1;
+    bool canonicalized = true;
+    for (const auto axis : op.getAxes()) {
+      if (axis <= last) {
+        canonicalized = false;
+        break;
+      }
+      last = axis;
+    }
+    if (canonicalized) {
+      return failure();
+    }
+    // canonicalize the reduce axes
+    auto axes = canonicalizeReduceAxes(op.getAxes(), rank);
+    rewriter.replaceOpWithNewOp<ReduceOp>(op, op.getType(), op.getOperand(),
+                                          axes, op.getKeepDims());
+    return success();
+  }
+};
};
120+
121+
#define REDUCE_OP_SHAPE_CANONICALIZE(OP) \
122+
void OP::getCanonicalizationPatterns(RewritePatternSet &results, \
123+
MLIRContext *context) { \
124+
results.add<CanonicalizeReduceOp<OP>>(context); \
125+
}
126+
127+
#define REDUCE_OP_SHAPE_INFER(OP) \
128+
LogicalResult OP::inferReturnTypeComponents( \
129+
MLIRContext *context, ::std::optional<Location> location, \
130+
OP::Adaptor adaptor, \
131+
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
132+
llvm::SmallVector<int64_t> outShape; \
133+
auto operandTy = dyn_cast<ShapedType>(adaptor.getOperand().getType()); \
134+
auto rank = operandTy.getRank(); \
135+
ShapeAdaptor inputShape(operandTy); \
136+
return InferReduceReturnTypes( \
137+
inputShape, operandTy.getElementType(), \
138+
canonicalizeReduceAxes(adaptor.getAxes(), rank), \
139+
adaptor.getKeepDims(), inferredReturnShapes); \
140+
}
141+
142+
#define REDUCE_OP_VERIFY(OP) \
143+
LogicalResult OP::verify() { \
144+
auto operandTy = dyn_cast<ShapedType>(getOperand().getType()); \
145+
if (!operandTy.hasRank()) { \
146+
return emitOpError("Invalid operand shape!\n"); \
147+
} \
148+
int64_t rank = operandTy.getRank(); \
149+
for (const auto axis : canonicalizeReduceAxes(getAxes(), rank)) { \
150+
if (axis >= rank || axis < 0) { \
151+
return emitOpError("Reduce axis not valid!\n"); \
152+
} \
153+
} \
154+
return success(); \
155+
}
156+
157+
#define REDUCE_OP_DEFINE(OP) \
158+
REDUCE_OP_SHAPE_CANONICALIZE(OP) \
159+
REDUCE_OP_SHAPE_INFER(OP) \
160+
REDUCE_OP_VERIFY(OP)
161+
162+
REDUCE_OP_DEFINE(onednn_graph::ReduceSumOp)
163+
REDUCE_OP_DEFINE(onednn_graph::ReduceMeanOp)
164+
165+
//===----------------------------------------------------------------------===//
166+
// Matmul ops shape infer
167+
//===----------------------------------------------------------------------===//
168+
38169
LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents(
39170
MLIRContext *context, ::std::optional<Location> location,
40171
MatMulOp::Adaptor adaptor,
@@ -134,7 +265,7 @@ LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents(
134265
SmallVector<int64_t> resultShape;
135266
if (!biasRankMatch ||
136267
!OpTrait::util::getBroadcastedShape(
137-
retShape.getDims(), biasType.dyn_cast<ShapedType>().getShape(),
268+
retShape.getDims(), dyn_cast<ShapedType>(biasType).getShape(),
138269
resultShape)) {
139270
return failure();
140271
}
