@@ -75,28 +75,28 @@ computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
75
75
}
76
76
77
77
// / Common code to create the reshape op where necessary to make the rank of the
78
- // / operations equal. Returns the updated input1 and input2 for the original
79
- // / input . The caller is expected to use these to rewrite the original operator
80
- // / with the RESHAPE now in the graph.
78
+ // / operations equal. input1 and input2 will be updated when the rank has
79
+ // / changed . The caller is expected to use these to rewrite the original
80
+ // / operator with the RESHAPE now in the graph.
81
81
static LogicalResult reshapeLowerToHigher (PatternRewriter &rewriter,
82
82
Location loc,
83
83
RankedTensorType outputType,
84
- Value input1, Value input2,
85
- Value &outInput1, Value &outInput2) {
84
+ Value &input1, Value &input2) {
86
85
auto input1Ty = input1.getType ().dyn_cast <RankedTensorType>();
87
86
auto input2Ty = input2.getType ().dyn_cast <RankedTensorType>();
88
87
89
- if (!input1Ty || !input2Ty)
90
- return failure ();
88
+ if (!input1Ty || !input2Ty) {
89
+ return rewriter.notifyMatchFailure (loc, " input not a ranked tensor" );
90
+ }
91
91
92
92
int64_t input1Rank = input1Ty.getRank ();
93
93
int64_t input2Rank = input2Ty.getRank ();
94
94
95
- Value higherTensorValue, lowerTensorValue;
96
- // Cannot rewrite as its already correct.
97
95
if (input1Rank == input2Rank)
98
- return failure ();
96
+ return rewriter.notifyMatchFailure (loc,
97
+ " cannot rewrite as its already correct" );
99
98
99
+ Value higherTensorValue, lowerTensorValue;
100
100
if (input1Rank > input2Rank) {
101
101
higherTensorValue = input1;
102
102
lowerTensorValue = input2;
@@ -107,15 +107,14 @@ static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
107
107
108
108
ArrayRef<int64_t > higherRankShape =
109
109
higherTensorValue.getType ().cast <RankedTensorType>().getShape ();
110
- (void )higherRankShape;
111
110
ArrayRef<int64_t > lowerRankShape =
112
111
lowerTensorValue.getType ().cast <RankedTensorType>().getShape ();
113
112
114
113
SmallVector<int64_t , 4 > reshapeOutputShape;
115
114
116
115
if (computeReshapeOutput (higherRankShape, lowerRankShape, reshapeOutputShape)
117
116
.failed ())
118
- return failure ( );
117
+ return rewriter. notifyMatchFailure (loc, " fail to compute a reshape type " );
119
118
120
119
auto reshapeInputType = lowerTensorValue.getType ().cast <RankedTensorType>();
121
120
auto reshapeOutputType = RankedTensorType::get (
@@ -125,26 +124,28 @@ static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
125
124
if (outputType) {
126
125
if (outputType.getShape ().size () != reshapeOutputShape.size () ||
127
126
outputType.getShape ().size () != higherRankShape.size ())
128
- return failure ();
127
+ return rewriter.notifyMatchFailure (
128
+ loc, " the reshaped type doesn't agrees with the ranked output type" );
129
129
}
130
130
131
131
auto reshapeLower = rewriter.create <tosa::ReshapeOp>(
132
132
loc, reshapeOutputType, lowerTensorValue,
133
133
rewriter.getDenseI64ArrayAttr (reshapeOutputShape));
134
134
135
135
if (input1Rank > input2Rank) {
136
- outInput1 = higherTensorValue;
137
- outInput2 = reshapeLower.getResult ();
136
+ input1 = higherTensorValue;
137
+ input2 = reshapeLower.getResult ();
138
138
} else {
139
- outInput1 = reshapeLower.getResult ();
140
- outInput2 = higherTensorValue;
139
+ input1 = reshapeLower.getResult ();
140
+ input2 = higherTensorValue;
141
141
}
142
142
143
143
return success ();
144
144
}
145
145
146
146
namespace {
147
- template <typename OpTy> struct ConvertTosaOp : public OpRewritePattern <OpTy> {
147
+ template <typename OpTy>
148
+ struct ConvertTosaOp : public OpRewritePattern <OpTy> {
148
149
using OpRewritePattern<OpTy>::OpRewritePattern;
149
150
150
151
LogicalResult matchAndRewrite (OpTy tosaBinaryOp,
@@ -158,14 +159,12 @@ template <typename OpTy> struct ConvertTosaOp : public OpRewritePattern<OpTy> {
158
159
if (!outputType)
159
160
return failure ();
160
161
161
- Value outInput1, outInput2;
162
162
if (reshapeLowerToHigher (rewriter, tosaBinaryOp.getLoc (), outputType,
163
- input1, input2, outInput1, outInput2 )
163
+ input1, input2)
164
164
.failed ())
165
165
return failure ();
166
166
167
- rewriter.replaceOpWithNewOp <OpTy>(tosaBinaryOp, outputType, outInput1,
168
- outInput2);
167
+ rewriter.replaceOpWithNewOp <OpTy>(tosaBinaryOp, outputType, input1, input2);
169
168
170
169
return success ();
171
170
}
@@ -188,14 +187,13 @@ struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> {
188
187
if (!outputType)
189
188
return failure ();
190
189
191
- Value outInput1, outInput2;
192
190
if (reshapeLowerToHigher (rewriter, tosaBinaryOp.getLoc (), outputType,
193
- input1, input2, outInput1, outInput2 )
191
+ input1, input2)
194
192
.failed ())
195
193
return failure ();
196
194
197
- rewriter.replaceOpWithNewOp <tosa::MulOp>(tosaBinaryOp, outputType,
198
- outInput1, outInput2 , shift);
195
+ rewriter.replaceOpWithNewOp <tosa::MulOp>(tosaBinaryOp, outputType, input1,
196
+ input2 , shift);
199
197
200
198
return success ();
201
199
}
@@ -220,14 +218,63 @@ struct ConvertTosaOp<tosa::ArithmeticRightShiftOp>
220
218
if (!outputType)
221
219
return failure ();
222
220
223
- Value outInput1, outInput2;
224
221
if (reshapeLowerToHigher (rewriter, tosaBinaryOp.getLoc (), outputType,
225
- input1, input2, outInput1, outInput2 )
222
+ input1, input2)
226
223
.failed ())
227
224
return failure ();
228
225
229
226
rewriter.replaceOpWithNewOp <tosa::ArithmeticRightShiftOp>(
230
- tosaBinaryOp, outputType, outInput1, outInput2, round);
227
+ tosaBinaryOp, outputType, input1, input2, round);
228
+
229
+ return success ();
230
+ }
231
+ };
232
+
233
+ template <>
234
+ struct ConvertTosaOp <tosa::SelectOp> : public OpRewritePattern<tosa::SelectOp> {
235
+ using OpRewritePattern<tosa::SelectOp>::OpRewritePattern;
236
+
237
+ LogicalResult matchAndRewrite (tosa::SelectOp tosaOp,
238
+ PatternRewriter &rewriter) const override {
239
+
240
+ Value input1 = tosaOp.getPred ();
241
+ Value input2 = tosaOp.getOnTrue ();
242
+ Value input3 = tosaOp.getOnFalse ();
243
+ Value output = tosaOp.getResult ();
244
+
245
+ auto outputType = output.getType ().dyn_cast <RankedTensorType>();
246
+ if (!outputType)
247
+ return rewriter.notifyMatchFailure (tosaOp, " output not a ranked tensor" );
248
+
249
+ // Apply broadcasting to each pair of inputs separately, and chain them as
250
+ // compound as below so that the broadcasting happens all at once.
251
+ bool reshaped1 = reshapeLowerToHigher (rewriter, tosaOp.getLoc (), outputType,
252
+ input1, input2)
253
+ .succeeded ();
254
+
255
+ bool reshaped2 = reshapeLowerToHigher (rewriter, tosaOp.getLoc (), outputType,
256
+ input1, input3)
257
+ .succeeded ();
258
+
259
+ bool reshaped3 = reshapeLowerToHigher (rewriter, tosaOp.getLoc (), outputType,
260
+ input2, input3)
261
+ .succeeded ();
262
+
263
+ if (!reshaped1 && !reshaped2 && !reshaped3)
264
+ return rewriter.notifyMatchFailure (
265
+ tosaOp,
266
+ " cannot rewrite as the rank of all operands is already aligned" );
267
+
268
+ int32_t result1Rank = input1.getType ().cast <RankedTensorType>().getRank ();
269
+ int32_t result2Rank = input2.getType ().cast <RankedTensorType>().getRank ();
270
+ int32_t result3Rank = input3.getType ().cast <RankedTensorType>().getRank ();
271
+
272
+ if ((result1Rank != result2Rank) || (result2Rank != result3Rank))
273
+ return rewriter.notifyMatchFailure (
274
+ tosaOp, " not all ranks are aligned with each other" );
275
+
276
+ rewriter.replaceOpWithNewOp <tosa::SelectOp>(tosaOp, outputType, input1,
277
+ input2, input3);
231
278
232
279
return success ();
233
280
}
@@ -263,6 +310,7 @@ struct TosaMakeBroadcastable
263
310
patterns.add <ConvertTosaOp<tosa::LogicalAndOp>>(ctx);
264
311
patterns.add <ConvertTosaOp<tosa::LogicalOrOp>>(ctx);
265
312
patterns.add <ConvertTosaOp<tosa::LogicalXorOp>>(ctx);
313
+ patterns.add <ConvertTosaOp<tosa::SelectOp>>(ctx);
266
314
patterns.add <ConvertTosaOp<tosa::PowOp>>(ctx);
267
315
(void )applyPatternsAndFoldGreedily (func, std::move (patterns));
268
316
}
0 commit comments