Skip to content

Commit dddeab1

Browse files
authored
Merge pull request #18 from Xilinx/tina.tosacherrypicksubfold
Enhanced splat folding for x+/-0 and x*1 in TOSA
2 parents 40037f4 + 0ccd387 commit dddeab1

File tree

3 files changed

+55
-194
lines changed

3 files changed

+55
-194
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,6 @@ def Tosa_AddOp : Tosa_Op<"add", [
419419
Tosa_Tensor:$output
420420
);
421421

422-
let hasCanonicalizer = 1;
423422
let hasFolder = 1;
424423
}
425424

@@ -738,7 +737,6 @@ def Tosa_MulOp : Tosa_Op<"mul", [
738737
Tosa_Tensor:$output
739738
);
740739

741-
let hasCanonicalizer = 1;
742740
let hasFolder = 1;
743741
}
744742

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 40 additions & 178 deletions
Original file line numberDiff line numberDiff line change
@@ -164,92 +164,6 @@ void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
164164
results.add<NoOpOptimization>(context);
165165
}
166166

167-
struct AddZeroOptimization : public OpRewritePattern<tosa::AddOp> {
168-
using OpRewritePattern::OpRewritePattern;
169-
170-
LogicalResult matchAndRewrite(tosa::AddOp op,
171-
PatternRewriter &rewriter) const override {
172-
auto input1 = op.getInput1();
173-
auto input2 = op.getInput2();
174-
175-
DenseElementsAttr input1Attr;
176-
if (matchPattern(input1, m_Constant(&input1Attr)) && input1Attr.isSplat() &&
177-
input2.getType() == op.getType()) {
178-
if (input1Attr.getType().getElementType().isa<IntegerType>() &&
179-
input1Attr.getSplatValue<APInt>().isZero()) {
180-
rewriter.replaceOp(op, op.getInput2());
181-
return success();
182-
}
183-
}
184-
185-
DenseElementsAttr input2Attr;
186-
if (matchPattern(input2, m_Constant(&input2Attr)) && input2Attr.isSplat() &&
187-
input1.getType() == op.getType()) {
188-
if (input2Attr.getType().getElementType().isa<IntegerType>() &&
189-
input2Attr.getSplatValue<APInt>().isZero()) {
190-
rewriter.replaceOp(op, op.getInput1());
191-
return success();
192-
}
193-
}
194-
195-
return failure();
196-
}
197-
};
198-
199-
void AddOp::getCanonicalizationPatterns(RewritePatternSet &results,
200-
MLIRContext *context) {
201-
results.add<AddZeroOptimization>(context);
202-
}
203-
204-
struct MulOneOptimization : public OpRewritePattern<tosa::MulOp> {
205-
using OpRewritePattern::OpRewritePattern;
206-
207-
LogicalResult matchAndRewrite(tosa::MulOp op,
208-
PatternRewriter &rewriter) const override {
209-
auto input1 = op.getInput1();
210-
auto input2 = op.getInput2();
211-
212-
DenseElementsAttr input1Attr;
213-
if (matchPattern(input1, m_Constant(&input1Attr)) && input1Attr.isSplat() &&
214-
input2.getType() == op.getType()) {
215-
if (input1Attr.getType().getElementType().isa<FloatType>() &&
216-
input1Attr.getSplatValue<APFloat>().isExactlyValue(1)) {
217-
rewriter.replaceOp(op, op.getInput2());
218-
return success();
219-
}
220-
221-
if (input1Attr.getType().getElementType().isa<IntegerType>() &&
222-
matchPattern(input1, m_One())) {
223-
rewriter.replaceOp(op, op.getInput2());
224-
return success();
225-
}
226-
}
227-
228-
DenseElementsAttr input2Attr;
229-
if (matchPattern(input2, m_Constant(&input2Attr)) && input2Attr.isSplat() &&
230-
input1.getType() == op.getType()) {
231-
if (input2Attr.getType().getElementType().isa<FloatType>() &&
232-
input2Attr.getSplatValue<APFloat>().isExactlyValue(1)) {
233-
rewriter.replaceOp(op, op.getInput1());
234-
return success();
235-
}
236-
237-
if (input2Attr.getType().getElementType().isa<IntegerType>() &&
238-
matchPattern(input2, m_One())) {
239-
rewriter.replaceOp(op, op.getInput1());
240-
return success();
241-
}
242-
}
243-
244-
return failure();
245-
}
246-
};
247-
248-
void MulOp::getCanonicalizationPatterns(RewritePatternSet &results,
249-
MLIRContext *context) {
250-
results.add<MulOneOptimization>(context);
251-
}
252-
253167
struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
254168
using OpRewritePattern::OpRewritePattern;
255169

@@ -468,64 +382,47 @@ DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
468382
return {};
469383
}
470384

385+
/// Returns true when `val` is a non-null splat attribute whose single value is
/// zero. `elemType` selects how the splat value is read: as an APFloat for
/// float element types, as an APInt for integer element types. Any other
/// element type yields false. For floats, APFloat::isZero() is used, which does
/// not look at the sign of the zero.
static bool isSplatZero(Type elemType, DenseElementsAttr val) {
  const bool isFloat = elemType.isa<FloatType>();
  // Only float and integer element types are handled.
  if (!isFloat && !elemType.isa<IntegerType>())
    return false;

  // A missing (non-constant) operand or a non-splat constant never matches.
  if (!val || !val.isSplat())
    return false;

  return isFloat ? val.getSplatValue<APFloat>().isZero()
                 : val.getSplatValue<APInt>().isZero();
}
392+
393+
/// Returns true when `val` is a non-null splat attribute whose single value is
/// the multiplicative identity for an element type of `elemType`.
///
/// \param elemType Element type that selects how the splat value is read.
/// \param val      Constant operand; may be null when the operand is not a
///                 constant.
/// \param shift    Shift amount applied to integer multiplications. For
///                 integer element types the identity is `1 << shift`; it is
///                 not used for float element types.
/// \returns false for null or non-splat attributes, for element types other
///          than float/integer, and for shift amounts whose `1 << shift` is
///          not representable as a positive value of the splat's bit width.
static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
  if (!val || !val.isSplat())
    return false;

  if (elemType.isa<FloatType>())
    return val.getSplatValue<APFloat>().isExactlyValue(1.0);

  if (elemType.isa<IntegerType>()) {
    const APInt splat = val.getSplatValue<APInt>();
    const int64_t bitWidth = static_cast<int64_t>(splat.getBitWidth());
    // Reject shifts that would be undefined for a native `1LL << shift` or
    // that would land on/after the sign bit: no positive signed value of this
    // width can equal `1 << shift` in that case.
    if (shift < 0 || shift >= bitWidth - 1)
      return false;
    // Compare in APInt so wide element types never hit the 64-bit limit of
    // getSExtValue().
    return splat == APInt::getOneBitSet(splat.getBitWidth(),
                                        static_cast<unsigned>(shift));
  }

  return false;
}
404+
471405
OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
472406
auto lhsTy = getInput1().getType().dyn_cast<RankedTensorType>();
473407
auto rhsTy = getInput2().getType().dyn_cast<RankedTensorType>();
474408
auto resultTy = getType().dyn_cast<RankedTensorType>();
475409
if (!lhsTy || !rhsTy || !resultTy)
476410
return {};
477-
411+
478412
auto resultETy = resultTy.getElementType();
479413
auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
480414
auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
481415

482-
if (lhsTy == resultTy) {
483-
if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<FloatType>()) {
484-
if (rhsAttr.getSplatValue<APFloat>().isZero())
485-
return getInput1();
486-
}
487-
}
488-
489-
if (lhsTy != rhsTy) {
490-
if (lhsAttr && rhsAttr) {
491-
if (lhsTy == resultTy && rhsAttr.isSplat()) {
492-
APFloat r = rhsAttr.getSplatValue<APFloat>();
493-
std::vector<APFloat> v;
494-
v.resize(lhsAttr.size(), APFloat(0.0));
495-
for(int i=0;i<lhsAttr.size(); ++i) {
496-
v[i] = lhsAttr.getValues<APFloat>()[i] + r;
497-
}
498-
return DenseElementsAttr::get(resultTy, v);
499-
}
500-
}
501-
}
502-
503-
504-
if (lhsAttr && lhsAttr.isSplat() && resultETy.isa<FloatType>()) {
505-
if (lhsAttr.getSplatValue<APFloat>().isZero())
506-
return getInput2();
507-
}
508-
509-
if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<FloatType>()) {
510-
if (rhsAttr.getSplatValue<APFloat>().isZero())
511-
return getInput1();
512-
}
513-
514-
if (lhsAttr && lhsAttr.isSplat() && resultETy.isa<IntegerType>()) {
515-
if (lhsAttr.getSplatValue<APInt>().isZero())
516-
return getInput2();
517-
}
518-
519-
if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<IntegerType>()) {
520-
if (rhsAttr.getSplatValue<APInt>().isZero())
521-
return getInput1();
522-
}
416+
if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
417+
return getInput1();
418+
if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
419+
return getInput2();
523420

524421
if (!lhsAttr || !rhsAttr)
525422
return {};
526423

527424
return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
528-
lhsTy);
425+
resultTy);
529426
}
530427

531428
OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
@@ -603,50 +500,26 @@ OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
603500
auto resultTy = getType().dyn_cast<RankedTensorType>();
604501
if (!lhsTy || !rhsTy || !resultTy)
605502
return {};
606-
if (lhsTy != rhsTy)
607-
return {};
608503

609504
auto resultETy = resultTy.getElementType();
610505
auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
611506
auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
612507

613-
if (lhsAttr && lhsAttr.isSplat() && resultETy.isa<FloatType>()) {
614-
auto val = lhsAttr.getSplatValue<APFloat>();
615-
if (val.isZero())
508+
const int64_t shift = resultETy.isa<IntegerType>() ? getShift() : 0;
509+
if (rhsTy == resultTy) {
510+
if (isSplatZero(resultETy, lhsAttr))
616511
return lhsAttr;
617-
if (val.isExactlyValue(1.0))
512+
if (isSplatOne(resultETy, lhsAttr, shift))
618513
return rhs;
619514
}
620-
621-
if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<FloatType>()) {
622-
auto val = rhsAttr.getSplatValue<APFloat>();
623-
if (val.isZero())
624-
return rhsAttr;
625-
if (val.isExactlyValue(1.0))
626-
return lhs;
627-
}
628-
629-
if (lhsAttr && lhsAttr.isSplat() && resultETy.isa<IntegerType>()) {
630-
auto val = lhsAttr.getSplatValue<APInt>();
631-
if (val.isZero())
632-
return lhsAttr;
633-
const int64_t shift = getShift();
634-
const int64_t shifted = 1LL << shift;
635-
if (val.getSExtValue() == shifted)
636-
return rhs;
637-
}
638-
639-
if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<IntegerType>()) {
640-
auto val = rhsAttr.getSplatValue<APInt>();
641-
const int64_t shift = getShift();
642-
const int64_t shifted = 1LL << shift;
643-
if (val.isZero())
515+
if (lhsTy == resultTy) {
516+
if (isSplatZero(resultETy, rhsAttr))
644517
return rhsAttr;
645-
if (val.getSExtValue() == shifted)
518+
if (isSplatOne(resultETy, rhsAttr, shift))
646519
return lhs;
647520
}
648521

649-
return mulBinaryFolder(lhsAttr, rhsAttr, lhsTy, getShift());
522+
return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, getShift());
650523
}
651524

652525
OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
@@ -655,28 +528,18 @@ OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
655528
auto resultTy = getType().dyn_cast<RankedTensorType>();
656529
if (!lhsTy || !rhsTy || !resultTy)
657530
return {};
658-
if (lhsTy != rhsTy)
659-
return {};
660531

661532
auto resultETy = resultTy.getElementType();
662533
auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
663534
auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
664-
665-
if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<FloatType>()) {
666-
if (rhsAttr.getSplatValue<APFloat>().isZero())
667-
return getInput1();
668-
}
669-
670-
if (rhsAttr && rhsAttr.isSplat() && resultETy.isa<IntegerType>()) {
671-
if (rhsAttr.getSplatValue<APInt>().isZero())
672-
return getInput1();
673-
}
535+
if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
536+
return getInput1();
674537

675538
if (!lhsAttr || !rhsAttr)
676539
return {};
677540

678541
return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
679-
lhsTy);
542+
resultTy);
680543
}
681544

682545
namespace {
@@ -917,7 +780,7 @@ OpFoldResult RsqrtOp::fold(FoldAdaptor adaptor) {
917780
auto operand = adaptor.getInput1().dyn_cast_or_null<ElementsAttr>();
918781
if (!operand)
919782
return {};
920-
783+
921784
if (!inputTy.getElementType().isF32())
922785
return {};
923786

@@ -947,7 +810,7 @@ OpFoldResult PowOp::fold(FoldAdaptor adaptor) {
947810
auto operand2 = adaptor.getInput2().dyn_cast_or_null<ElementsAttr>();
948811
if (!operand2)
949812
return {};
950-
813+
951814
if (!operand1.getElementType().isF32() || !operand2.getElementType().isF32())
952815
return {};
953816

@@ -961,7 +824,7 @@ OpFoldResult PowOp::fold(FoldAdaptor adaptor) {
961824

962825
OpFoldResult ReciprocalOp::fold(FoldAdaptor adaptor) {
963826
auto src = adaptor.getInput1().dyn_cast_or_null<mlir::DenseElementsAttr>();
964-
827+
965828
if (!src)
966829
return nullptr;
967830

@@ -989,7 +852,6 @@ OpFoldResult ReverseOp::fold(ArrayRef<Attribute> operands) {
989852
return {};
990853
}
991854

992-
993855
OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
994856
auto inputTy = getInput().getType().dyn_cast<RankedTensorType>();
995857
auto outputTy = getType().dyn_cast<RankedTensorType>();

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,15 @@ func.func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
77
return %0 : tensor<?x1xf32>
88
}
99

10-
// CHECK-LABEL: @add_zero_different_shape
11-
func.func @add_zero_different_shape(%arg0: tensor<2x3xi32>) -> tensor<4x2x3xi32> {
12-
// CHECK: tosa.add
13-
%zeros = "tosa.const"() {value = dense<0> : tensor<4x2x3xi32>} : () -> tensor<4x2x3xi32>
14-
%1 = "tosa.add"(%arg0, %zeros) : (tensor<2x3xi32>, tensor<4x2x3xi32>) -> tensor<4x2x3xi32>
10+
// CHECK-LABEL: @add_bcast_zero_int
11+
func.func @add_bcast_zero_int(%arg0: tensor<4x2x3xi32>) -> tensor<4x2x3xi32> {
12+
// CHECK-NOT: tosa.add
13+
// CHECK: return %arg0
14+
%zeros = "tosa.const"() {value = dense<0> : tensor<1x1x1xi32>} : () -> tensor<1x1x1xi32>
15+
%1 = "tosa.add"(%arg0, %zeros) : (tensor<4x2x3xi32>, tensor<1x1x1xi32>) -> tensor<4x2x3xi32>
1516
return %1 : tensor<4x2x3xi32>
1617
}
1718

18-
1919
// CHECK-LABEL: @add_zero_int
2020
func.func @add_zero_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
2121
// CHECK: return %arg0
@@ -176,14 +176,6 @@ func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi3
176176
return %1 : tensor<?x?xi32>
177177
}
178178

179-
// CHECK-LABEL: @mul_one_different_shape
180-
func.func @mul_one_different_shape(%arg0: tensor<2x3xf32>) -> tensor<4x2x3xf32> {
181-
// CHECK: tosa.mul
182-
%ones = "tosa.const"() {value = dense<1.0> : tensor<4x2x3xf32>} : () -> tensor<4x2x3xf32>
183-
%1 = "tosa.mul"(%arg0, %ones) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<4x2x3xf32>) -> tensor<4x2x3xf32>
184-
return %1 : tensor<4x2x3xf32>
185-
}
186-
187179
// CHECK-LABEL: @mul_one_float
188180
func.func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
189181
// CHECK: return %arg0
@@ -193,6 +185,15 @@ func.func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
193185
return %1 : tensor<2x3xf32>
194186
}
195187

188+
// CHECK-LABEL: @mul_bcast_one_float
189+
func.func @mul_bcast_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
190+
// CHECK: return %arg0
191+
// CHECK-NOT: tosa.mul
192+
%ones = "tosa.const"() {value = dense<1.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
193+
%1 = "tosa.mul"(%ones, %arg0) {shift = 0 : i32} : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
194+
return %1 : tensor<2x3xf32>
195+
}
196+
196197
// CHECK-LABEL: @mul_one_int
197198
func.func @mul_one_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
198199
// CHECK: return %arg0

0 commit comments

Comments
 (0)