
Commit 2b2ebb6

[mlir][tosa] Add folders for trivial tosa operation cases
Some folding cases are trivial to fold away, specifically no-op cases where an operation's input and output are the same. Canonicalizing these away removes unneeded operations. The current version includes tensor cast operations to resolve shape discrepancies that occur when an operation's result type differs from its input type; these are resolved during a tosa shape propagation pass.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D107321
1 parent 86858c6 commit 2b2ebb6
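To make the intent concrete, here is a hypothetical IR sketch (operand names, shapes, and attribute values are invented for illustration, not taken from the patch) of the no-op cases the new folders and the concat canonicalization remove:

// Hypothetical example, assuming %arg0 : tensor<2x1xf32>.

// A cast whose result type matches its input folds to the input:
%0 = "tosa.cast"(%arg0) : (tensor<2x1xf32>) -> tensor<2x1xf32>   // folds to %arg0

// A reduction along an axis of size 1 leaves the tensor unchanged:
%1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<2x1xf32>) -> tensor<2x1xf32>   // folds to %arg0

// A single-operand concat whose result type differs from its operand is
// rewritten to a tensor.cast, to be reconciled later by shape propagation:
%2 = "tosa.concat"(%arg0) {axis = 0 : i64} : (tensor<2x1xf32>) -> tensor<?x1xf32>
// becomes: %2 = tensor.cast %arg0 : tensor<2x1xf32> to tensor<?x1xf32>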

File tree

6 files changed: 359 additions, 19 deletions


mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 2 additions & 0 deletions
@@ -37,6 +37,8 @@ def Tosa_Dialect : Dialect {
     there will be tools to lower from the ML frameworks into TOSA.
   }];
 
+  let dependentDialects = ["tensor::TensorDialect"];
+
   let cppNamespace = "mlir::tosa";
   let hasConstantMaterializer = 1;
 }

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@
 #define MLIR_DIALECT_TOSA_IR_TOSAOPS_H
 
 #include "mlir/Dialect/Quant/QuantOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Traits.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 23 additions & 0 deletions
@@ -1227,6 +1227,8 @@ def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [
   let results = (outs
     Tosa_Tensor1Dto4D:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1250,6 +1252,8 @@ def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [
   let results = (outs
     Tosa_Tensor1Dto4D:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1273,6 +1277,8 @@ def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [
   let results = (outs
     Tosa_Tensor1Dto4D:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1296,6 +1302,8 @@ def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [
   let results = (outs
     Tosa_Tensor1Dto4D:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1319,6 +1327,8 @@ def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [
   let results = (outs
     Tosa_Tensor1Dto4D:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1342,6 +1352,8 @@ def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [
   let results = (outs
     Tosa_Tensor1Dto4D:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1371,6 +1383,8 @@ def Tosa_ConcatOp : Tosa_Op<"concat", [
   let results = (outs
     Tosa_RankedTensor:$output
   );
+
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1415,6 +1429,7 @@ def Tosa_ReshapeOp: Tosa_Op<"reshape", [
   }];
 
   let hasCanonicalizer = 1;
+  let hasFolder = 1;
 
   let arguments = (ins
     Tosa_Tensor:$input1,
@@ -1473,6 +1488,8 @@ def Tosa_SliceOp: Tosa_Op<"slice", [
   let results = (outs
     Tosa_Tensor1Dto6D:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1495,6 +1512,8 @@ def Tosa_TileOp: Tosa_Op<"tile", [
   let results = (outs
     Tosa_Tensor1Dto4D:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1518,6 +1537,8 @@ def Tosa_TransposeOp : Tosa_Op<"transpose", [
   let results = (
     outs Tosa_Tensor1Dto6D:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1655,6 +1676,8 @@ def Tosa_CastOp: Tosa_Op<"cast", [NoSideEffect,
   let results = (outs
     Tosa_Tensor:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 89 additions & 6 deletions
@@ -14,6 +14,7 @@
 
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -107,19 +108,31 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
 // Operator Canonicalizers.
 //===----------------------------------------------------------------------===//
 
-struct RemoveReshapeNoop : public OpRewritePattern<tosa::ReshapeOp> {
-  using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;
+struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
+  using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(tosa::ReshapeOp op,
+  LogicalResult matchAndRewrite(tosa::ConcatOp op,
                                 PatternRewriter &rewriter) const override {
-    if (op.input1().getType() != op.getType())
+    if (op.input1().size() != 1)
       return failure();
+    if (op.input1().front().getType() != op.getType()) {
+      rewriter
+          .replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
+                                              op.input1().front())
+          .getResult();
+      return success();
+    }
 
-    rewriter.replaceOp(op, op.input1());
+    rewriter.replaceOp(op, op.input1().front());
    return success();
   }
 };
 
+void ConcatOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                           MLIRContext *context) {
+  results.insert<ConcatOptimization>(context);
+}
+
 struct ReshapeReshapeOptimization : public OpRewritePattern<tosa::ReshapeOp> {
   using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;
 
@@ -142,18 +155,88 @@ struct ReshapeReshapeOptimization : public OpRewritePattern<tosa::ReshapeOp> {
 
 void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
-  results.insert<ReshapeReshapeOptimization, RemoveReshapeNoop>(context);
+  results.insert<ReshapeReshapeOptimization>(context);
 }
 
 //===----------------------------------------------------------------------===//
 // Operator Folders.
 //===----------------------------------------------------------------------===//
 
+OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
+  if (input().getType() == getType())
+    return input();
+  return {};
+}
+
 OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.empty() && "constant has no operands");
   return valueAttr();
 }
 
+#define ReduceFolder(OP)                                                       \
+  OpFoldResult OP::fold(ArrayRef<Attribute> operands) {                        \
+    ShapedType inputTy = input().getType().cast<ShapedType>();                 \
+    if (!inputTy.hasRank())                                                    \
+      return {};                                                               \
+    if (inputTy.getDimSize(axis()) == 1)                                       \
+      return input();                                                          \
+    return {};                                                                 \
+  }
+
+ReduceFolder(ReduceAllOp) ReduceFolder(ReduceAnyOp) ReduceFolder(ReduceMaxOp)
+    ReduceFolder(ReduceMinOp) ReduceFolder(ReduceProdOp)
+        ReduceFolder(ReduceSumOp)
+#undef ReduceFolder
+
+OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
+  auto inputTy = input1().getType().dyn_cast<RankedTensorType>();
+  auto outputTy = getType().dyn_cast<RankedTensorType>();
+
+  if (!inputTy || !outputTy || inputTy != outputTy)
+    return {};
+  return input1();
+}
+
+OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
+  auto inputTy = input().getType().dyn_cast<RankedTensorType>();
+  auto outputTy = getType().dyn_cast<RankedTensorType>();
+
+  if (!inputTy || !outputTy || inputTy != outputTy)
+    return {};
+  if (inputTy.hasStaticShape())
+    return input();
+
+  return {};
+}
+
+OpFoldResult TileOp::fold(ArrayRef<Attribute> operands) {
+  bool allOnes = true;
+  for (Attribute val : multiples().getValue()) {
+    allOnes = allOnes && val.cast<IntegerAttr>().getValue().getSExtValue() == 1;
+  }
+
+  if (allOnes && input1().getType() == getType())
+    return input1();
+  return {};
+}
+
+OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
+  if (!operands[1])
+    return {};
+
+  DenseIntElementsAttr perms = operands[1].cast<DenseIntElementsAttr>();
+
+  bool isRange = true;
+  for (auto it : llvm::enumerate(perms)) {
+    isRange = isRange &&
              it.value().getSExtValue() == static_cast<int64_t>(it.index());
+  }
+
+  if (isRange && input1().getType() == getType())
+    return input1();
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Verifiers.
 //===----------------------------------------------------------------------===//
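
To illustrate the new folders, here is a hypothetical snippet (shapes, names, and attribute values are invented for this sketch) in which each op is a provable no-op and therefore folds to %arg0:

// Assuming %arg0 : tensor<2x3xf32>.

// Identity transpose: the permutation operand is a constant 0, 1, ... range:
%perms = "tosa.const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
%0 = "tosa.transpose"(%arg0, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x3xf32>

// Tile with all-ones multiples and matching input/output types:
%1 = "tosa.tile"(%arg0) {multiples = [1, 1]} : (tensor<2x3xf32>) -> tensor<2x3xf32>

// Reshape and slice whose ranked result type equals the (static) input type:
%2 = "tosa.reshape"(%arg0) {new_shape = [2, 3]} : (tensor<2x3xf32>) -> tensor<2x3xf32>
%3 = "tosa.slice"(%arg0) {start = [0, 0], size = [2, 3]} : (tensor<2x3xf32>) -> tensor<2x3xf32>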

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 5 additions & 13 deletions
@@ -237,13 +237,9 @@ func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
   // CHECK: fptrunc
   %23 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf16>
 
-  // CHECK: linalg.generic
-  // CHECK: yield
-  %24 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf32>
-
   // CHECK: linalg.generic
   // CHECK: divf
-  %25 = "tosa.reciprocal"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+  %24 = "tosa.reciprocal"(%0) : (tensor<1xf32>) -> tensor<1xf32>
 
   return
 }
@@ -383,29 +379,25 @@ func @test_simple_i32(%arg0: tensor<1xi32>) -> () {
   // CHECK: trunci
   %20 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>
 
-  // CHECK: linalg.generic
-  // CHECK: yield
-  %21 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi32>
-
   // CHECK: linalg.generic
   // CHECK: sexti
-  %22 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64>
+  %21 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64>
 
   // CHECK: linalg.generic
   // CHECK: constant 0
   // CHECK: cmpi
-  %23 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1>
+  %22 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1>
 
   // CHECK: linalg.generic
   // CHECK: sitofp
-  %24 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32>
+  %23 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: constant 0
   // CHECK: cmpi sgt
   // CHECK: subi
   // CHECK: select
-  %25 = "tosa.abs"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
+  %24 = "tosa.abs"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
 
   return
 }
