17
17
namespace mlir {
18
18
namespace onednn_graph {
19
19
20
- LogicalResult onednn_graph::AddOp::inferReturnTypeComponents (
21
- MLIRContext *context, ::std::optional<Location> location,
22
- ValueShapeRange operands, DictionaryAttr attributes,
23
- OpaqueProperties properties, RegionRange regions,
24
- SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
25
- llvm::SmallVector<int64_t > outShape;
26
- auto resultTy = dyn_cast<ShapedType>(operands.front ().getType ());
27
- auto getShapeIdx = [&operands](size_t i) {
28
- return operands.getTypes ()[i].dyn_cast <ShapedType>().getShape ();
20
+ // ===----------------------------------------------------------------------===//
21
+ // Binary ops shape infer
22
+ // ===----------------------------------------------------------------------===//
23
+
24
+ #define BINARY_OP_SHAPE_INFER (OP ) \
25
+ LogicalResult OP::inferReturnTypeComponents ( \
26
+ MLIRContext *context, ::std::optional<Location> location, \
27
+ OP::Adaptor adaptor, \
28
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
29
+ auto inputTy0 = dyn_cast<ShapedType>(adaptor.getInputA ().getType ()); \
30
+ auto inputTy1 = dyn_cast<ShapedType>(adaptor.getInputB ().getType ()); \
31
+ if (!adaptor.getAutoBroadcast () && (inputTy0 != inputTy1)) { \
32
+ return failure (); \
33
+ } \
34
+ llvm::SmallVector<int64_t > outShape; \
35
+ auto ret = OpTrait::util::getBroadcastedShape ( \
36
+ inputTy0.getShape (), inputTy1.getShape (), outShape); \
37
+ inferredReturnShapes.push_back ( \
38
+ ShapedTypeComponents (outShape, inputTy0.getElementType ())); \
39
+ return LogicalResult::success (ret); \
40
+ }
41
+
42
+ BINARY_OP_SHAPE_INFER (onednn_graph::AddOp)
43
+ BINARY_OP_SHAPE_INFER (onednn_graph::MulOp)
44
+ BINARY_OP_SHAPE_INFER (onednn_graph::SubOp)
45
+ BINARY_OP_SHAPE_INFER (onednn_graph::DivOp)
46
+
47
+ // ===----------------------------------------------------------------------===//
48
+ // Reduce ops shape infer
49
+ // ===----------------------------------------------------------------------===//
50
+
51
+ SmallVector<int64_t > canonicalizeReduceAxes (ArrayRef<int64_t > axes,
52
+ int64_t rank) {
53
+ SmallVector<int64_t > ret (axes.size ());
54
+ for (size_t i = 0 ; i < axes.size (); i++) {
55
+ ret[i] = axes[i] < 0 ? axes[i] + rank : axes[i];
56
+ }
57
+ llvm::sort (ret);
58
+ ret.erase (std::unique (ret.begin (), ret.end ()), ret.end ());
59
+ return ret;
60
+ }
61
+
62
+ SmallVector<int64_t > getReducedShape (ShapeAdaptor operandShape,
63
+ ArrayRef<int64_t > axes, bool keep_dims) {
64
+ SmallVector<int64_t > outputShape;
65
+ // get reduce axis one by one
66
+ size_t index = 0 ;
67
+ auto getNextReduceAxis = [&]() {
68
+ return (index >= axes.size ()) ? -1 : axes[index++];
29
69
};
70
+ // get reduced shape
71
+ auto rank = operandShape.getRank ();
72
+ auto axis = getNextReduceAxis ();
73
+ for (int64_t idx = 0 ; idx < rank; idx++) {
74
+ if (idx == axis) {
75
+ axis = getNextReduceAxis ();
76
+ if (keep_dims) {
77
+ outputShape.push_back (1 );
78
+ }
79
+ } else {
80
+ outputShape.push_back (operandShape.getDimSize (idx));
81
+ }
82
+ }
83
+ return outputShape;
84
+ }
30
85
31
- auto ret = OpTrait::util::getBroadcastedShape (getShapeIdx (0 ), getShapeIdx (1 ),
32
- outShape);
33
- inferredReturnShapes.push_back (
34
- ShapedTypeComponents (outShape, resultTy.getElementType ()));
35
- return LogicalResult::success (ret);
86
+ static LogicalResult InferReduceReturnTypes (
87
+ ShapeAdaptor operandShape, Type elemType, ArrayRef<int64_t > axes,
88
+ bool keep_dims,
89
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
90
+ // no reduce axes
91
+ if (axes.empty ()) {
92
+ inferredReturnShapes.push_back (ShapedTypeComponents (operandShape));
93
+ return success ();
94
+ }
95
+ inferredReturnShapes.push_back (ShapedTypeComponents (
96
+ getReducedShape (operandShape, axes, keep_dims), elemType));
97
+ return success ();
36
98
}
37
99
100
+ template <typename ReduceOp>
101
+ struct CanonicalizeReduceOp : public OpRewritePattern <ReduceOp> {
102
+ using OpRewritePattern<ReduceOp>::OpRewritePattern;
103
+ LogicalResult matchAndRewrite (ReduceOp op,
104
+ PatternRewriter &rewriter) const override {
105
+ auto rank = dyn_cast<ShapedType>(op.getOperand ().getType ()).getRank ();
106
+ // consider canonicalized if all axes are non-negative in ascending order
107
+ // Note: disable tidy here due to dangling reference in OperationState
108
+ // NOLINTBEGIN
109
+ bool canonicalized = true ;
110
+ int64_t last = -1 ;
111
+ for (const auto axis : op.getAxes ()) {
112
+ if (axis <= last) {
113
+ canonicalized = false ;
114
+ break ;
115
+ }
116
+ last = axis;
117
+ }
118
+ if (canonicalized) {
119
+ return failure ();
120
+ }
121
+ // canonicalize the reduce axes
122
+ auto new_axes = canonicalizeReduceAxes (op.getAxes (), rank);
123
+ auto new_op = rewriter.create <ReduceOp>(
124
+ op.getLoc (), op.getType (), op.getOperand (), new_axes, op.getKeepDims ());
125
+ rewriter.replaceOp (op, new_op);
126
+ // NOLINTEND
127
+ return success ();
128
+ }
129
+ };
130
+
131
// Register the axis-canonicalization pattern for a reduce op.
#define REDUCE_OP_SHAPE_CANONICALIZE(OP)                                       \
  void OP::getCanonicalizationPatterns(RewritePatternSet &results,             \
                                       MLIRContext *context) {                 \
    results.add<CanonicalizeReduceOp<OP>>(context);                            \
  }
137
+
138
// Infer the result type of a reduce op from its operand type, its axes
// (canonicalized first so negative/unsorted axes infer the same shape), and
// the keep_dims attribute.
#define REDUCE_OP_SHAPE_INFER(OP)                                              \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::std::optional<Location> location,                \
      OP::Adaptor adaptor,                                                     \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    auto operandTy = dyn_cast<ShapedType>(adaptor.getOperand().getType());     \
    /* Guard: a non-shaped operand yields a null type; getRank() on it */      \
    /* would assert.                                                    */     \
    if (!operandTy) {                                                          \
      return failure();                                                        \
    }                                                                          \
    auto rank = operandTy.getRank();                                           \
    ShapeAdaptor inputShape(operandTy);                                        \
    return InferReduceReturnTypes(                                             \
        inputShape, operandTy.getElementType(),                                \
        canonicalizeReduceAxes(adaptor.getAxes(), rank),                       \
        adaptor.getKeepDims(), inferredReturnShapes);                          \
  }
152
+
153
// Verify a reduce op: the operand must be a ranked shaped type, and every
// axis (after canonicalization) must lie within [0, rank).
#define REDUCE_OP_VERIFY(OP)                                                   \
  LogicalResult OP::verify() {                                                 \
    auto operandTy = dyn_cast<ShapedType>(getOperand().getType());             \
    /* Guard the null case: hasRank() on a null ShapedType would assert, */    \
    /* so treat a non-shaped operand as an invalid shape too.            */    \
    if (!operandTy || !operandTy.hasRank()) {                                  \
      return emitOpError("Invalid operand shape!\n");                          \
    }                                                                          \
    int64_t rank = operandTy.getRank();                                        \
    for (const auto axis : canonicalizeReduceAxes(getAxes(), rank)) {          \
      if (axis >= rank || axis < 0) {                                          \
        return emitOpError("Reduce axis not valid!\n");                        \
      }                                                                        \
    }                                                                          \
    return success();                                                          \
  }
167
+
168
+ #define REDUCE_OP_DEFINE (OP ) \
169
+ REDUCE_OP_SHAPE_CANONICALIZE (OP) \
170
+ REDUCE_OP_SHAPE_INFER (OP) \
171
+ REDUCE_OP_VERIFY (OP)
172
+
173
+ REDUCE_OP_DEFINE (onednn_graph::ReduceSumOp)
174
+ REDUCE_OP_DEFINE (onednn_graph::ReduceMeanOp)
175
+
176
+ // ===----------------------------------------------------------------------===//
177
+ // Matmul ops shape infer
178
+ // ===----------------------------------------------------------------------===//
179
+
38
180
LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents (
39
181
MLIRContext *context, ::std::optional<Location> location,
40
182
MatMulOp::Adaptor adaptor,
@@ -44,6 +186,7 @@ LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents(
44
186
const ShapeAdaptor &rhsShape, int64_t range,
45
187
int64_t diff, SmallVector<int64_t > &outDims) {
46
188
for (int64_t i = 0 ; i < range; i++) {
189
+ // TODO(longsheng): add OpTrait::util::getBroadcastedShape for batch
47
190
int64_t idx = i - diff;
48
191
if ((idx >= 0 ) && (lhsShape.getDimSize (i) != rhsShape.getDimSize (idx))) {
49
192
return failure ();
@@ -134,7 +277,7 @@ LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents(
134
277
SmallVector<int64_t > resultShape;
135
278
if (!biasRankMatch ||
136
279
!OpTrait::util::getBroadcastedShape (
137
- retShape.getDims (), biasType. dyn_cast <ShapedType>().getShape (),
280
+ retShape.getDims (), dyn_cast<ShapedType>(biasType ).getShape (),
138
281
resultShape)) {
139
282
return failure ();
140
283
}
0 commit comments