@@ -48,6 +48,8 @@ BINARY_OP_SHAPE_INFER(onednn_graph::DivOp)
 // Reduce ops shape infer
 // ===----------------------------------------------------------------------===//
 
+// canonicalize reduced axes
+// make all axes inside the reduced axes array non-negative and in ascending order
 SmallVector<int64_t> canonicalizeReduceAxes(ArrayRef<int64_t> axes,
                                             int64_t rank) {
   SmallVector<int64_t> ret(axes.size());
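For reference, a minimal standalone sketch of what this helper does (plain C++, no MLIR dependencies; the function name and body here are assumptions based on the comment above, since the hunk elides the actual body): negative axes are wrapped into [0, rank) and the result is sorted ascending.

#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for canonicalizeReduceAxes: wrap negative axes
// into [0, rank) by adding the rank, then sort ascending.
std::vector<int64_t> canonicalizeAxesSketch(const std::vector<int64_t> &axes,
                                            int64_t rank) {
  std::vector<int64_t> ret(axes.size());
  for (size_t i = 0; i < axes.size(); ++i)
    ret[i] = axes[i] < 0 ? axes[i] + rank : axes[i];
  std::sort(ret.begin(), ret.end());
  return ret;
}

int main() {
  // e.g. axes = {-1, 0} with rank 4 canonicalizes to {0, 3}
  auto axes = canonicalizeAxesSketch({-1, 0}, /*rank=*/4);
  return axes == std::vector<int64_t>{0, 3} ? 0 : 1;
}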
@@ -59,21 +61,46 @@ SmallVector<int64_t> canonicalizeReduceAxes(ArrayRef<int64_t> axes,
   return ret;
 }
 
-SmallVector<int64_t> getReducedShape(ShapeAdaptor operandShape,
-                                     ArrayRef<int64_t> axes, bool keep_dims) {
-  SmallVector<int64_t> outputShape;
+// canonicalize kept axes
+// make all axes inside the kept axes array non-negative and in ascending order
+SmallVector<int64_t> canonicalizeKeepAxes(ArrayRef<int64_t> axes, int64_t rank,
+                                          bool canonicalized = false) {
+  // get canonicalized reduce axes
+  auto newCanonicalized = canonicalized ? SmallVector<int64_t>{}
+                                        : canonicalizeReduceAxes(axes, rank);
+  auto reduceAxes = canonicalized ? axes : ArrayRef<int64_t>(newCanonicalized);
+  // get kept axes
+  SmallVector<int64_t> keepAxes;
+  for (int64_t dim = 0, idx = 0; dim < rank; dim++) {
+    if (idx < (int64_t)reduceAxes.size() && reduceAxes[idx] == dim) {
+      idx++;
+      continue;
+    }
+    keepAxes.push_back(dim);
+  }
+  return keepAxes;
+}
+
+SmallVector<int64_t> inferReducedShape(ShapedType operandShape,
+                                       ArrayRef<int64_t> axes, bool keepDims,
+                                       bool canonicalized = false) {
+  // get canonicalized reduce axes
+  auto rank = operandShape.getRank();
+  auto newCanonicalized = canonicalized ? SmallVector<int64_t>{}
+                                        : canonicalizeReduceAxes(axes, rank);
+  auto reduceAxes = canonicalized ? axes : ArrayRef<int64_t>(newCanonicalized);
   // get reduce axis one by one
   size_t index = 0;
   auto getNextReduceAxis = [&]() {
-    return (index >= axes.size()) ? -1 : axes[index++];
+    return (index >= reduceAxes.size()) ? -1 : reduceAxes[index++];
   };
   // get reduced shape
-  auto rank = operandShape.getRank();
+  SmallVector<int64_t> outputShape;
   auto axis = getNextReduceAxis();
   for (int64_t idx = 0; idx < rank; idx++) {
     if (idx == axis) {
       axis = getNextReduceAxis();
-      if (keep_dims) {
+      if (keepDims) {
         outputShape.push_back(1);
       }
     } else {
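A worked example of the reduced-shape loop added above, as a standalone sketch (plain C++; std::vector stands in for SmallVector/ShapedType, and only the loop visible in inferReducedShape is mirrored — axes are assumed already canonicalized):

#include <cassert>
#include <cstdint>
#include <vector>

// Mirror of inferReducedShape's loop: emit 1 (or nothing) for reduced
// axes depending on keepDims, and copy all other dims through.
std::vector<int64_t> reducedShape(const std::vector<int64_t> &shape,
                                  const std::vector<int64_t> &axes,
                                  bool keepDims) {
  std::vector<int64_t> out;
  size_t idx = 0;
  for (int64_t dim = 0; dim < (int64_t)shape.size(); ++dim) {
    if (idx < axes.size() && axes[idx] == dim) {
      ++idx;
      if (keepDims)
        out.push_back(1);
    } else {
      out.push_back(shape[dim]);
    }
  }
  return out;
}

int main() {
  // shape [2, 3, 4], reduce axis 2 (i.e. -1 before canonicalization):
  assert((reducedShape({2, 3, 4}, {2}, false) == std::vector<int64_t>{2, 3}));
  assert((reducedShape({2, 3, 4}, {2}, true) ==
          std::vector<int64_t>{2, 3, 1}));
  // the complementary canonicalizeKeepAxes result here would be {0, 1}
}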
@@ -84,16 +111,16 @@ SmallVector<int64_t> getReducedShape(ShapeAdaptor operandShape,
 }
 
 static LogicalResult InferReduceReturnTypes(
-    ShapeAdaptor operandShape, Type elemType, ArrayRef<int64_t> axes,
-    bool keep_dims,
+    ShapedType operandTy, ArrayRef<int64_t> axes, bool keepDims,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   // no reduce axes
   if (axes.empty()) {
-    inferredReturnShapes.push_back(ShapedTypeComponents(operandShape));
+    inferredReturnShapes.push_back(ShapedTypeComponents(operandTy));
     return success();
   }
-  inferredReturnShapes.push_back(ShapedTypeComponents(
-      getReducedShape(operandShape, axes, keep_dims), elemType));
+  inferredReturnShapes.push_back(
+      ShapedTypeComponents(inferReducedShape(operandTy, axes, keepDims),
+                           operandTy.getElementType()));
   return success();
 }
 
@@ -119,10 +146,10 @@ struct CanonicalizeReduceOp : public OpRewritePattern<ReduceOp> {
       return failure();
     }
     // canonicalize the reduce axes
-    auto new_axes = canonicalizeReduceAxes(op.getAxes(), rank);
-    auto new_op = rewriter.create<ReduceOp>(
-        op.getLoc(), op.getType(), op.getOperand(), new_axes, op.getKeepDims());
-    rewriter.replaceOp(op, new_op);
+    auto newAxes = canonicalizeReduceAxes(op.getAxes(), rank);
+    auto newOp = rewriter.create<ReduceOp>(
+        op.getLoc(), op.getType(), op.getOperand(), newAxes, op.getKeepDims());
+    rewriter.replaceOp(op, newOp);
     // NOLINTEND
     return success();
   }
@@ -142,12 +169,9 @@ struct CanonicalizeReduceOp : public OpRewritePattern<ReduceOp> {
       SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {          \
     llvm::SmallVector<int64_t> outShape;                                      \
     auto operandTy = dyn_cast<ShapedType>(adaptor.getOperand().getType());    \
-    auto rank = operandTy.getRank();                                          \
-    ShapeAdaptor inputShape(operandTy);                                       \
-    return InferReduceReturnTypes(                                            \
-        inputShape, operandTy.getElementType(),                               \
-        canonicalizeReduceAxes(adaptor.getAxes(), rank),                      \
-        adaptor.getKeepDims(), inferredReturnShapes);                         \
+    return InferReduceReturnTypes(operandTy, adaptor.getAxes(),               \
+                                  adaptor.getKeepDims(),                      \
+                                  inferredReturnShapes);                      \
   }
 
 #define REDUCE_OP_VERIFY(OP)                                                  \
@@ -181,22 +205,37 @@ LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     MatMulOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  // get batch dims from shape
-  auto extractBatch = [](const ShapeAdaptor &lhsShape,
-                         const ShapeAdaptor &rhsShape, int64_t range,
-                         int64_t diff, SmallVector<int64_t> &outDims) {
-    for (int64_t i = 0; i < range; i++) {
-      // TODO(longsheng): add OpTrait::util::getBroadcastedShape for batch
-      int64_t idx = i - diff;
-      if ((idx >= 0) && (lhsShape.getDimSize(i) != rhsShape.getDimSize(idx))) {
-        return failure();
-      }
-      outDims.push_back(lhsShape.getDimSize(i));
+  // get batch dims from one multi-batch mat shape
+  auto extractBatch = [](ShapedType shape, SmallVector<int64_t> &outDims) {
+    // assuming the last 2 input dims are row and col
+    assert(shape.getRank() >= 2);
+    for (int64_t i = 0; i < shape.getRank() - 2; i++) {
+      outDims.push_back(shape.getDimSize(i));
     }
     return success();
   };
+  // get broadcasted batch dims from two multi-batch mat shapes
+  auto extractBroadcastBatch = [](ShapedType lhsType, ShapedType rhsType,
+                                  SmallVector<int64_t> &outDims) {
+    SmallVector<int64_t> lhsShape(lhsType.getShape());
+    SmallVector<int64_t> rhsShape(rhsType.getShape());
+    assert(lhsShape.size() >= 2 && rhsShape.size() >= 2);
+    // assuming the last 2 input dims are row and col
+    // 0xFF is just an arbitrary number > 1 that replaces the row and col dims
+    // so that getBroadcastedShape can match; it is removed afterwards
+    lhsShape[lhsShape.size() - 1] = 0xFF;
+    lhsShape[lhsShape.size() - 2] = 0xFF;
+    rhsShape[rhsShape.size() - 1] = 0xFF;
+    rhsShape[rhsShape.size() - 2] = 0xFF;
+    bool ret = OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, outDims);
+    // remove 0xFF
+    assert(outDims.size() >= 2);
+    outDims.pop_back();
+    outDims.pop_back();
+    return LogicalResult::success(ret);
+  };
   // get row col of 2d matrix according to transpose info
-  auto getMatRowCol = [](const ShapeAdaptor &shape, bool transpose) {
+  auto getMatRowCol = [](ShapedType shape, bool transpose) {
     using pairRowCol = std::pair<int64_t, int64_t>;
     auto rank = shape.getRank();
     assert(rank > 1);
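The 0xFF sentinel trick in extractBroadcastBatch, illustrated standalone (plain C++; broadcastShapes below is a simplified numpy-style stand-in for OpTrait::util::getBroadcastedShape, not the real MLIR routine, and the shapes are hypothetical):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Simplified right-aligned broadcast: pad the shorter shape with 1s on the
// left, then require equal dims or a 1 in either position.
static bool broadcastShapes(std::vector<int64_t> a, std::vector<int64_t> b,
                            std::vector<int64_t> &out) {
  if (a.size() < b.size())
    std::swap(a, b);
  b.insert(b.begin(), a.size() - b.size(), 1); // left-pad with 1s
  out.clear();
  for (size_t i = 0; i < a.size(); ++i) {
    if (a[i] != b[i] && a[i] != 1 && b[i] != 1)
      return false; // incompatible batch dims
    out.push_back(std::max(a[i], b[i]));
  }
  return true;
}

int main() {
  // lhs [5, 1, 8, 16, 32] x rhs [2, 8, 32, 64]:
  std::vector<int64_t> lhs{5, 1, 8, 16, 32}, rhs{2, 8, 32, 64}, out;
  // mask the row/col dims with an equal sentinel so only batch dims
  // need to be reconciled by the broadcast routine
  lhs[3] = lhs[4] = 0xFF;
  rhs[2] = rhs[3] = 0xFF;
  bool ok = broadcastShapes(lhs, rhs, out);
  out.pop_back(); // drop the two sentinel dims again
  out.pop_back();
  assert(ok && (out == std::vector<int64_t>{5, 2, 8}));
}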
@@ -205,8 +244,8 @@ LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents(
                      : pairRowCol(shape.getDimSize(rank - 2),
                                   shape.getDimSize(rank - 1));
   };
-  ShapeAdaptor lhsShape(adaptor.getInputA().getType());
-  ShapeAdaptor rhsShape(adaptor.getInputB().getType());
+  auto lhsShape = cast<ShapedType>(adaptor.getInputA().getType());
+  auto rhsShape = cast<ShapedType>(adaptor.getInputB().getType());
   bool transposeA = adaptor.getTransposeA();
   bool transposeB = adaptor.getTransposeB();
   int64_t lRank = lhsShape.getRank();
@@ -223,36 +262,24 @@ LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents(
   } else if (lRank == 1 && rRank > 1) {
     // 1D x ND
     auto rMatRowCol = getMatRowCol(rhsShape, transposeB);
-    status = extractBatch(rhsShape, rhsShape, rRank - 2, 0, outShape);
+    status = extractBatch(rhsShape, outShape);
     if (lhsShape.getDimSize(0) != rMatRowCol.first) {
       return failure();
     }
     outShape.push_back(rhsShape.getDimSize(rMatRowCol.second));
   } else if (lRank > 1 && rRank == 1) {
     // ND x 1D
     auto lMatRowCol = getMatRowCol(lhsShape, transposeA);
-    status = extractBatch(lhsShape, lhsShape, lRank - 2, 0, outShape);
+    status = extractBatch(lhsShape, outShape);
    if (lMatRowCol.second != rhsShape.getDimSize(0)) {
       return failure();
     }
     outShape.push_back(lhsShape.getDimSize(lMatRowCol.first));
   } else if (lRank > 1 && rRank > 1) {
-    if (lRank == rRank) {
-      // ND x ND
-      auto range = lRank - 2;
-      status = extractBatch(lhsShape, rhsShape, range, 0, outShape);
-    } else if (lRank > rRank) {
-      // MD x ND (M > N)
-      auto range = lRank - 2;
-      auto diff = lRank - rRank;
-      status = extractBatch(lhsShape, rhsShape, range, diff, outShape);
-    } else {
-      // ND x MD (M > N)
-      auto range = rRank - 2;
-      auto diff = rRank - lRank;
-      status = extractBatch(rhsShape, lhsShape, range, diff, outShape);
-    }
-    //
+    // ND x ND
+    // MD x ND (M > N)
+    // ND x MD (M > N)
+    status = extractBroadcastBatch(lhsShape, rhsShape, outShape);
     auto lMatRowCol = getMatRowCol(lhsShape, transposeA);
     auto rMatRowCol = getMatRowCol(rhsShape, transposeB);
     if (failed(status) || (lMatRowCol.second != rMatRowCol.first)) {
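How the ND x ND branch composes, on hypothetical shapes (a standalone C++ sketch; matRowCol mirrors getMatRowCol above, the broadcast and assembly steps are spelled out in comments):

#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Transpose-aware (row, col) of the trailing 2-D matrix.
static std::pair<int64_t, int64_t> matRowCol(const std::vector<int64_t> &s,
                                             bool transpose) {
  size_t r = s.size();
  return transpose ? std::make_pair(s[r - 1], s[r - 2])
                   : std::make_pair(s[r - 2], s[r - 1]);
}

int main() {
  // lhs [7, 1, 5, 6] x rhs [1, 3, 6, 4], no transposes:
  std::vector<int64_t> lhs{7, 1, 5, 6}, rhs{1, 3, 6, 4};
  auto l = matRowCol(lhs, /*transpose=*/false); // (5, 6)
  auto r = matRowCol(rhs, /*transpose=*/false); // (6, 4)
  assert(l.second == r.first); // inner dims match: 6 == 6
  // batch dims [7, 1] and [1, 3] broadcast to [7, 3], so the inferred
  // output shape is [7, 3, 5, 4].
  // with transposeA = true, l would be (6, 5) and 5 != 6 would fail.
}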
@@ -269,16 +296,13 @@ LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents(
   inferredReturnShapes.push_back(retShape);
   // check for bias broadcasting
   if (adaptor.getBias()) {
-    auto biasType = adaptor.getBias().getType();
-    ShapeAdaptor biasShape(biasType);
-
-    bool biasRankMatch = biasShape.getRank() == 1 ||
-                         biasShape.getRank() == (int64_t)outShape.size();
+    auto biasType = dyn_cast<ShapedType>(adaptor.getBias().getType());
+    bool biasRankMatch = biasType.getRank() == 1 ||
+                         biasType.getRank() == (int64_t)outShape.size();
     SmallVector<int64_t> resultShape;
     if (!biasRankMatch ||
-        !OpTrait::util::getBroadcastedShape(
-            retShape.getDims(), dyn_cast<ShapedType>(biasType).getShape(),
-            resultShape)) {
+        !OpTrait::util::getBroadcastedShape(retShape.getDims(),
+                                            biasType.getShape(), resultShape)) {
       return failure();
     }
   }
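Finally, the bias check admits either a rank-1 bias or a bias whose rank matches the result before broadcasting is even attempted. A standalone sketch of just that rank gate (hypothetical shapes; the broadcast step itself is delegated to getBroadcastedShape and elided here):

#include <cstdint>
#include <cstdio>
#include <vector>

// Mirror of the biasRankMatch gate: bias must be rank 1 or match the
// result rank before the broadcast check runs.
static bool biasRankMatch(const std::vector<int64_t> &result,
                          const std::vector<int64_t> &bias) {
  return bias.size() == 1 || bias.size() == result.size();
}

int main() {
  std::vector<int64_t> result{7, 3, 5, 4};
  std::printf("%d\n", biasRankMatch(result, {4}));          // 1: rank-1 bias
  std::printf("%d\n", biasRankMatch(result, {7, 1, 5, 4})); // 1: full rank
  std::printf("%d\n", biasRankMatch(result, {5, 4}));       // 0: rank 2 fails
}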