Skip to content

Commit 936819b

Browse files
tatwaichong authored and
rsuderman committed
[mlir][tosa] make Select operator broadcastable in the pass
Making Select broadcastable makes this op easier to use.

Change-Id: I4a4bec4f7cbe532e954a5b4fe53136676ab4300c

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D139156
1 parent e87cc8a commit 936819b

File tree

2 files changed

+165
-29
lines changed

2 files changed

+165
-29
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp

Lines changed: 77 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -75,28 +75,28 @@ computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
7575
}
7676

7777
/// Common code to create the reshape op where necessary to make the rank of the
78-
/// operations equal. Returns the updated input1 and input2 for the original
79-
/// input. The caller is expected to use these to rewrite the original operator
80-
/// with the RESHAPE now in the graph.
78+
/// operations equal. input1 and input2 will be updated when the rank has
79+
/// changed. The caller is expected to use these to rewrite the original
80+
/// operator with the RESHAPE now in the graph.
8181
static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
8282
Location loc,
8383
RankedTensorType outputType,
84-
Value input1, Value input2,
85-
Value &outInput1, Value &outInput2) {
84+
Value &input1, Value &input2) {
8685
auto input1Ty = input1.getType().dyn_cast<RankedTensorType>();
8786
auto input2Ty = input2.getType().dyn_cast<RankedTensorType>();
8887

89-
if (!input1Ty || !input2Ty)
90-
return failure();
88+
if (!input1Ty || !input2Ty) {
89+
return rewriter.notifyMatchFailure(loc, "input not a ranked tensor");
90+
}
9191

9292
int64_t input1Rank = input1Ty.getRank();
9393
int64_t input2Rank = input2Ty.getRank();
9494

95-
Value higherTensorValue, lowerTensorValue;
96-
// Cannot rewrite as its already correct.
9795
if (input1Rank == input2Rank)
98-
return failure();
96+
return rewriter.notifyMatchFailure(loc,
97+
"cannot rewrite as its already correct");
9998

99+
Value higherTensorValue, lowerTensorValue;
100100
if (input1Rank > input2Rank) {
101101
higherTensorValue = input1;
102102
lowerTensorValue = input2;
@@ -107,15 +107,14 @@ static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
107107

108108
ArrayRef<int64_t> higherRankShape =
109109
higherTensorValue.getType().cast<RankedTensorType>().getShape();
110-
(void)higherRankShape;
111110
ArrayRef<int64_t> lowerRankShape =
112111
lowerTensorValue.getType().cast<RankedTensorType>().getShape();
113112

114113
SmallVector<int64_t, 4> reshapeOutputShape;
115114

116115
if (computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape)
117116
.failed())
118-
return failure();
117+
return rewriter.notifyMatchFailure(loc, "fail to compute a reshape type");
119118

120119
auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
121120
auto reshapeOutputType = RankedTensorType::get(
@@ -125,26 +124,28 @@ static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
125124
if (outputType) {
126125
if (outputType.getShape().size() != reshapeOutputShape.size() ||
127126
outputType.getShape().size() != higherRankShape.size())
128-
return failure();
127+
return rewriter.notifyMatchFailure(
128+
loc, "the reshaped type doesn't agrees with the ranked output type");
129129
}
130130

131131
auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
132132
loc, reshapeOutputType, lowerTensorValue,
133133
rewriter.getDenseI64ArrayAttr(reshapeOutputShape));
134134

135135
if (input1Rank > input2Rank) {
136-
outInput1 = higherTensorValue;
137-
outInput2 = reshapeLower.getResult();
136+
input1 = higherTensorValue;
137+
input2 = reshapeLower.getResult();
138138
} else {
139-
outInput1 = reshapeLower.getResult();
140-
outInput2 = higherTensorValue;
139+
input1 = reshapeLower.getResult();
140+
input2 = higherTensorValue;
141141
}
142142

143143
return success();
144144
}
145145

146146
namespace {
147-
template <typename OpTy> struct ConvertTosaOp : public OpRewritePattern<OpTy> {
147+
template <typename OpTy>
148+
struct ConvertTosaOp : public OpRewritePattern<OpTy> {
148149
using OpRewritePattern<OpTy>::OpRewritePattern;
149150

150151
LogicalResult matchAndRewrite(OpTy tosaBinaryOp,
@@ -158,14 +159,12 @@ template <typename OpTy> struct ConvertTosaOp : public OpRewritePattern<OpTy> {
158159
if (!outputType)
159160
return failure();
160161

161-
Value outInput1, outInput2;
162162
if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
163-
input1, input2, outInput1, outInput2)
163+
input1, input2)
164164
.failed())
165165
return failure();
166166

167-
rewriter.replaceOpWithNewOp<OpTy>(tosaBinaryOp, outputType, outInput1,
168-
outInput2);
167+
rewriter.replaceOpWithNewOp<OpTy>(tosaBinaryOp, outputType, input1, input2);
169168

170169
return success();
171170
}
@@ -188,14 +187,13 @@ struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> {
188187
if (!outputType)
189188
return failure();
190189

191-
Value outInput1, outInput2;
192190
if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
193-
input1, input2, outInput1, outInput2)
191+
input1, input2)
194192
.failed())
195193
return failure();
196194

197-
rewriter.replaceOpWithNewOp<tosa::MulOp>(tosaBinaryOp, outputType,
198-
outInput1, outInput2, shift);
195+
rewriter.replaceOpWithNewOp<tosa::MulOp>(tosaBinaryOp, outputType, input1,
196+
input2, shift);
199197

200198
return success();
201199
}
@@ -220,14 +218,63 @@ struct ConvertTosaOp<tosa::ArithmeticRightShiftOp>
220218
if (!outputType)
221219
return failure();
222220

223-
Value outInput1, outInput2;
224221
if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
225-
input1, input2, outInput1, outInput2)
222+
input1, input2)
226223
.failed())
227224
return failure();
228225

229226
rewriter.replaceOpWithNewOp<tosa::ArithmeticRightShiftOp>(
230-
tosaBinaryOp, outputType, outInput1, outInput2, round);
227+
tosaBinaryOp, outputType, input1, input2, round);
228+
229+
return success();
230+
}
231+
};
232+
233+
template <>
234+
struct ConvertTosaOp<tosa::SelectOp> : public OpRewritePattern<tosa::SelectOp> {
235+
using OpRewritePattern<tosa::SelectOp>::OpRewritePattern;
236+
237+
LogicalResult matchAndRewrite(tosa::SelectOp tosaOp,
238+
PatternRewriter &rewriter) const override {
239+
240+
Value input1 = tosaOp.getPred();
241+
Value input2 = tosaOp.getOnTrue();
242+
Value input3 = tosaOp.getOnFalse();
243+
Value output = tosaOp.getResult();
244+
245+
auto outputType = output.getType().dyn_cast<RankedTensorType>();
246+
if (!outputType)
247+
return rewriter.notifyMatchFailure(tosaOp, "output not a ranked tensor");
248+
249+
// Apply broadcasting to each pair of inputs in turn; the calls below are
250+
// chained so that the broadcasting of all operands happens in one rewrite.
251+
bool reshaped1 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
252+
input1, input2)
253+
.succeeded();
254+
255+
bool reshaped2 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
256+
input1, input3)
257+
.succeeded();
258+
259+
bool reshaped3 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
260+
input2, input3)
261+
.succeeded();
262+
263+
if (!reshaped1 && !reshaped2 && !reshaped3)
264+
return rewriter.notifyMatchFailure(
265+
tosaOp,
266+
"cannot rewrite as the rank of all operands is already aligned");
267+
268+
int32_t result1Rank = input1.getType().cast<RankedTensorType>().getRank();
269+
int32_t result2Rank = input2.getType().cast<RankedTensorType>().getRank();
270+
int32_t result3Rank = input3.getType().cast<RankedTensorType>().getRank();
271+
272+
if ((result1Rank != result2Rank) || (result2Rank != result3Rank))
273+
return rewriter.notifyMatchFailure(
274+
tosaOp, "not all ranks are aligned with each other");
275+
276+
rewriter.replaceOpWithNewOp<tosa::SelectOp>(tosaOp, outputType, input1,
277+
input2, input3);
231278

232279
return success();
233280
}
@@ -263,6 +310,7 @@ struct TosaMakeBroadcastable
263310
patterns.add<ConvertTosaOp<tosa::LogicalAndOp>>(ctx);
264311
patterns.add<ConvertTosaOp<tosa::LogicalOrOp>>(ctx);
265312
patterns.add<ConvertTosaOp<tosa::LogicalXorOp>>(ctx);
313+
patterns.add<ConvertTosaOp<tosa::SelectOp>>(ctx);
266314
patterns.add<ConvertTosaOp<tosa::PowOp>>(ctx);
267315
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
268316
}

mlir/test/Dialect/Tosa/broadcast.mlir

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,3 +195,91 @@ func.func @test_broadcast_scalar(%arg0: tensor<i32>, %arg1: tensor<17x16x15x14xi
195195
%0 = "tosa.add"(%arg0, %arg1) : (tensor<i32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
196196
return %0 : tensor<17x16x15x14xi32>
197197
}
198+
199+
// -----
200+
// CHECK-LABEL: broadcast_select_both_input
201+
func.func @test_broadcast_select_both_input(%arg0: tensor<1x16x16xi1>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<1x16x16xf32> {
202+
// CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg1) {new_shape = array<i64: 1, 1, 1>}
203+
// CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = array<i64: 1, 1, 1>}
204+
// CHECK: %[[VAL_2:.*]] = "tosa.select"(%arg0, %[[VAL_0]], %[[VAL_1]])
205+
%0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x16x16xi1>, tensor<f32>, tensor<f32>) -> tensor<1x16x16xf32>
206+
return %0 : tensor<1x16x16xf32>
207+
}
208+
209+
// -----
210+
// CHECK-LABEL: broadcast_select_one_input
211+
func.func @test_broadcast_select_one_input(%arg0: tensor<17x16x15x14xi1>, %arg1: tensor<17x16x15x14xf32>, %arg2: tensor<f32>) -> tensor<17x16x15x14xf32> {
212+
// CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg2) {new_shape = array<i64: 1, 1, 1, 1>}
213+
// CHECK: %[[VAL_1:.*]] = "tosa.select"(%arg0, %arg1, %[[VAL_0]])
214+
%0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<17x16x15x14xi1>, tensor<17x16x15x14xf32>, tensor<f32>) -> tensor<17x16x15x14xf32>
215+
return %0 : tensor<17x16x15x14xf32>
216+
}
217+
218+
// -----
219+
// CHECK-LABEL: broadcast_select_predicate
220+
func.func @test_broadcast_select_predicate(%arg0: tensor<i1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
221+
// CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 1, 1, 1>}
222+
// CHECK: %[[VAL_1:.*]] = "tosa.select"(%[[VAL_0]], %arg1, %arg2)
223+
%0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<1x32x32x8xf32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
224+
return %0 : tensor<1x32x32x8xf32>
225+
}
226+
227+
// -----
228+
// CHECK-LABEL: broadcast_select_abc
229+
func.func @test_broadcast_select_abc(%arg0: tensor<i1>, %arg1: tensor<32x8xf32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
230+
// CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 1, 1, 1>}
231+
// CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg1) {new_shape = array<i64: 1, 1, 32, 8>}
232+
// CHECK: %[[VAL_2:.*]] = "tosa.select"(%[[VAL_0]], %[[VAL_1]], %arg2)
233+
%0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<32x8xf32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
234+
return %0 : tensor<1x32x32x8xf32>
235+
}
236+
237+
// -----
238+
// CHECK-LABEL: broadcast_select_acb
239+
func.func @test_broadcast_select_acb(%arg0: tensor<i1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<32x8xf32>) -> tensor<1x32x32x8xf32> {
240+
// CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 1, 1, 1>}
241+
// CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = array<i64: 1, 1, 32, 8>}
242+
// CHECK: %[[VAL_2:.*]] = "tosa.select"(%[[VAL_0]], %arg1, %[[VAL_1]])
243+
%0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<1x32x32x8xf32>, tensor<32x8xf32>) -> tensor<1x32x32x8xf32>
244+
return %0 : tensor<1x32x32x8xf32>
245+
}
246+
247+
// -----
248+
// CHECK-LABEL: broadcast_select_bac
249+
func.func @test_broadcast_select_bac(%arg0: tensor<32x8xi1>, %arg1: tensor<f32>, %arg2: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
250+
// CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 1, 32, 8>}
251+
// CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg1) {new_shape = array<i64: 1, 1, 1, 1>}
252+
// CHECK: %[[VAL_2:.*]] = "tosa.select"(%[[VAL_0]], %[[VAL_1]], %arg2)
253+
%0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<32x8xi1>, tensor<f32>, tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
254+
return %0 : tensor<1x32x32x8xf32>
255+
}
256+
257+
// -----
258+
// CHECK-LABEL: broadcast_select_bca
259+
func.func @test_broadcast_select_bca(%arg0: tensor<32x8xi1>, %arg1: tensor<1x32x32x8xf32>, %arg2: tensor<i1>) -> tensor<1x32x32x8xf32> {
260+
// CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg0) {new_shape = array<i64: 1, 1, 32, 8>}
261+
// CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = array<i64: 1, 1, 1, 1>}
262+
// CHECK: %[[VAL_2:.*]] = "tosa.select"(%[[VAL_0]], %arg1, %[[VAL_1]])
263+
%0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<32x8xi1>, tensor<1x32x32x8xf32>, tensor<i1>) -> tensor<1x32x32x8xf32>
264+
return %0 : tensor<1x32x32x8xf32>
265+
}
266+
267+
// -----
268+
// CHECK-LABEL: broadcast_select_cab
269+
func.func @test_broadcast_select_cab(%arg0: tensor<1x32x32x8xi1>, %arg1: tensor<f32>, %arg2: tensor<32x8xf32>) -> tensor<1x32x32x8xf32> {
270+
// CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg1) {new_shape = array<i64: 1, 1, 1, 1>}
271+
// CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = array<i64: 1, 1, 32, 8>}
272+
// CHECK: %[[VAL_2:.*]] = "tosa.select"(%arg0, %[[VAL_0]], %[[VAL_1]])
273+
%0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x32x32x8xi1>, tensor<f32>, tensor<32x8xf32>) -> tensor<1x32x32x8xf32>
274+
return %0 : tensor<1x32x32x8xf32>
275+
}
276+
277+
// -----
278+
// CHECK-LABEL: broadcast_select_cba
279+
func.func @test_broadcast_select_cba(%arg0: tensor<1x32x32x8xi1>, %arg1: tensor<32x8xf32>, %arg2: tensor<i1>) -> tensor<1x32x32x8xf32> {
280+
// CHECK-DAG: %[[VAL_0:.*]] = "tosa.reshape"(%arg1) {new_shape = array<i64: 1, 1, 32, 8>}
281+
// CHECK-DAG: %[[VAL_1:.*]] = "tosa.reshape"(%arg2) {new_shape = array<i64: 1, 1, 1, 1>}
282+
// CHECK: %[[VAL_2:.*]] = "tosa.select"(%arg0, %[[VAL_0]], %[[VAL_1]])
283+
%0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x32x32x8xi1>, tensor<32x8xf32>, tensor<i1>) -> tensor<1x32x32x8xf32>
284+
return %0 : tensor<1x32x32x8xf32>
285+
}

0 commit comments

Comments
 (0)