Skip to content

Commit 2e7aa7e

Browse files
authored
[mlir][tosa] Add custom operand getters for select op (#145921)
The select op has 3 inputs: input1, input2, input3 to according to the tosa specification. However, use of getInput1(), getInput2() and getInput3() in the codebase can be confusing and hinder readability. This commit adds custom getters to help improve readability: - input1 -> getPred() - input2 -> getOnTrue() - input3 -> getOnFalse() They should be preferred as they are more descriptive, however, the ODS generated getters (getInputX()) may still be used. Unfortunately the custom getters don't propagate to Adaptors such as `FoldAdaptor`, so the ODS generated getters must be used.
1 parent 473769e commit 2e7aa7e

File tree

5 files changed

+23
-16
lines changed

5 files changed

+23
-16
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1490,9 +1490,9 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
14901490
}];
14911491

14921492
let arguments = (ins
1493-
Tosa_I1Tensor:$input1,
1494-
Tosa_Tensor:$input2,
1495-
Tosa_Tensor:$input3
1493+
Tosa_I1Tensor:$input1, // pred
1494+
Tosa_Tensor:$input2, // on true
1495+
Tosa_Tensor:$input3 // on false
14961496
);
14971497

14981498
let results = (outs
@@ -1512,6 +1512,13 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
15121512
operands attr-dict `:` `(` type($input1) `,` type($input2) `,` type($input3)
15131513
`)` `->` type($output)
15141514
}];
1515+
1516+
let extraClassDeclaration = [{
1517+
// Custom getters for readability
1518+
::mlir::TypedValue<::mlir::TensorType> getPred() { return getInput1(); }
1519+
::mlir::TypedValue<::mlir::TensorType> getOnTrue() { return getInput2(); }
1520+
::mlir::TypedValue<::mlir::TensorType> getOnFalse() { return getInput3(); }
1521+
}];
15151522
}
15161523

15171524
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
344344
return failure();
345345
rewriter.modifyOpInPlace(op, [&]() {
346346
op.getOperation()->setOperands(
347-
{notOp.getInput1(), op.getInput3(), op.getInput2()});
347+
{notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
348348
});
349349
return success();
350350
}
@@ -1510,8 +1510,8 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
15101510
}
15111511

15121512
OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
1513-
if (getInput2() == getInput3())
1514-
return getInput2();
1513+
if (getOnTrue() == getOnFalse())
1514+
return getOnTrue();
15151515

15161516
auto predicate =
15171517
llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getInput1());
@@ -1520,8 +1520,8 @@ OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
15201520

15211521
if (!predicate.isSplat())
15221522
return {};
1523-
return predicate.getSplatValue<APInt>().getBoolValue() ? getInput2()
1524-
: getInput3();
1523+
return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
1524+
: getOnFalse();
15251525
}
15261526

15271527
OpFoldResult TileOp::fold(FoldAdaptor adaptor) {

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3829,16 +3829,16 @@ LogicalResult ReverseOp::verify() {
38293829

38303830
LogicalResult tosa::SelectOp::verify() {
38313831
// verify input2 and input3 have same element type as output
3832-
if (verifySameElementTypes(*this, /* inType = */ getInput2().getType(),
3832+
if (verifySameElementTypes(*this, /* inType = */ getOnTrue().getType(),
38333833
/* outType = */ getOutput().getType())
38343834
.failed() ||
3835-
verifySameElementTypes(*this, /* inType = */ getInput3().getType(),
3835+
verifySameElementTypes(*this, /* inType = */ getOnFalse().getType(),
38363836
/* outType = */ getOutput().getType())
38373837
.failed()) {
38383838
return failure();
38393839
}
38403840
// verify input1 has element type of bool
3841-
auto predicateType = llvm::dyn_cast<ShapedType>(getInput1().getType());
3841+
auto predicateType = llvm::dyn_cast<ShapedType>(getPred().getType());
38423842
if (!predicateType) {
38433843
return emitOpError("expect shaped tensor for input1, got ")
38443844
<< getInput1().getType();

mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,9 @@ struct ConvertTosaOp<tosa::SelectOp> : public OpRewritePattern<tosa::SelectOp> {
169169
LogicalResult matchAndRewrite(tosa::SelectOp tosaOp,
170170
PatternRewriter &rewriter) const override {
171171

172-
Value input1 = tosaOp.getInput1();
173-
Value input2 = tosaOp.getInput2();
174-
Value input3 = tosaOp.getInput3();
172+
Value input1 = tosaOp.getPred();
173+
Value input2 = tosaOp.getOnTrue();
174+
Value input3 = tosaOp.getOnFalse();
175175
Value output = tosaOp.getResult();
176176

177177
auto outputType = dyn_cast<RankedTensorType>(output.getType());

mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,8 +188,8 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) {
188188

189189
template <>
190190
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
191-
addValue(op.getInput2());
192-
addValue(op.getInput3());
191+
addValue(op.getOnTrue());
192+
addValue(op.getOnFalse());
193193
addValue(op.getOutput());
194194
return success();
195195
}

0 commit comments

Comments
 (0)