@@ -164,92 +164,6 @@ void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
164
164
results.add <NoOpOptimization>(context);
165
165
}
166
166
167
- struct AddZeroOptimization : public OpRewritePattern <tosa::AddOp> {
168
- using OpRewritePattern::OpRewritePattern;
169
-
170
- LogicalResult matchAndRewrite (tosa::AddOp op,
171
- PatternRewriter &rewriter) const override {
172
- auto input1 = op.getInput1 ();
173
- auto input2 = op.getInput2 ();
174
-
175
- DenseElementsAttr input1Attr;
176
- if (matchPattern (input1, m_Constant (&input1Attr)) && input1Attr.isSplat () &&
177
- input2.getType () == op.getType ()) {
178
- if (input1Attr.getType ().getElementType ().isa <IntegerType>() &&
179
- input1Attr.getSplatValue <APInt>().isZero ()) {
180
- rewriter.replaceOp (op, op.getInput2 ());
181
- return success ();
182
- }
183
- }
184
-
185
- DenseElementsAttr input2Attr;
186
- if (matchPattern (input2, m_Constant (&input2Attr)) && input2Attr.isSplat () &&
187
- input1.getType () == op.getType ()) {
188
- if (input2Attr.getType ().getElementType ().isa <IntegerType>() &&
189
- input2Attr.getSplatValue <APInt>().isZero ()) {
190
- rewriter.replaceOp (op, op.getInput1 ());
191
- return success ();
192
- }
193
- }
194
-
195
- return failure ();
196
- }
197
- };
198
-
199
- void AddOp::getCanonicalizationPatterns (RewritePatternSet &results,
200
- MLIRContext *context) {
201
- results.add <AddZeroOptimization>(context);
202
- }
203
-
204
- struct MulOneOptimization : public OpRewritePattern <tosa::MulOp> {
205
- using OpRewritePattern::OpRewritePattern;
206
-
207
- LogicalResult matchAndRewrite (tosa::MulOp op,
208
- PatternRewriter &rewriter) const override {
209
- auto input1 = op.getInput1 ();
210
- auto input2 = op.getInput2 ();
211
-
212
- DenseElementsAttr input1Attr;
213
- if (matchPattern (input1, m_Constant (&input1Attr)) && input1Attr.isSplat () &&
214
- input2.getType () == op.getType ()) {
215
- if (input1Attr.getType ().getElementType ().isa <FloatType>() &&
216
- input1Attr.getSplatValue <APFloat>().isExactlyValue (1 )) {
217
- rewriter.replaceOp (op, op.getInput2 ());
218
- return success ();
219
- }
220
-
221
- if (input1Attr.getType ().getElementType ().isa <IntegerType>() &&
222
- matchPattern (input1, m_One ())) {
223
- rewriter.replaceOp (op, op.getInput2 ());
224
- return success ();
225
- }
226
- }
227
-
228
- DenseElementsAttr input2Attr;
229
- if (matchPattern (input2, m_Constant (&input2Attr)) && input2Attr.isSplat () &&
230
- input1.getType () == op.getType ()) {
231
- if (input2Attr.getType ().getElementType ().isa <FloatType>() &&
232
- input2Attr.getSplatValue <APFloat>().isExactlyValue (1 )) {
233
- rewriter.replaceOp (op, op.getInput1 ());
234
- return success ();
235
- }
236
-
237
- if (input2Attr.getType ().getElementType ().isa <IntegerType>() &&
238
- matchPattern (input2, m_One ())) {
239
- rewriter.replaceOp (op, op.getInput1 ());
240
- return success ();
241
- }
242
- }
243
-
244
- return failure ();
245
- }
246
- };
247
-
248
- void MulOp::getCanonicalizationPatterns (RewritePatternSet &results,
249
- MLIRContext *context) {
250
- results.add <MulOneOptimization>(context);
251
- }
252
-
253
167
struct MaterializePadValue : public OpRewritePattern <tosa::PadOp> {
254
168
using OpRewritePattern::OpRewritePattern;
255
169
@@ -468,64 +382,47 @@ DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
468
382
return {};
469
383
}
470
384
385
+ static bool isSplatZero (Type elemType, DenseElementsAttr val) {
386
+ if (elemType.isa <FloatType>())
387
+ return val && val.isSplat () && val.getSplatValue <APFloat>().isZero ();
388
+ if (elemType.isa <IntegerType>())
389
+ return val && val.isSplat () && val.getSplatValue <APInt>().isZero ();
390
+ return false ;
391
+ }
392
+
393
+ static bool isSplatOne (Type elemType, DenseElementsAttr val, int64_t shift) {
394
+ if (elemType.isa <FloatType>())
395
+ return val && val.isSplat () &&
396
+ val.getSplatValue <APFloat>().isExactlyValue (1.0 );
397
+ if (elemType.isa <IntegerType>()) {
398
+ const int64_t shifted = 1LL << shift;
399
+ return val && val.isSplat () &&
400
+ val.getSplatValue <APInt>().getSExtValue () == shifted;
401
+ }
402
+ return false ;
403
+ }
404
+
471
405
/// Folds `tosa.add`.
///
/// Returns an empty result unless both inputs and the result are ranked
/// tensors. Otherwise:
///   * if the right operand is a zero splat and the left operand's type equals
///     the result type, the left operand is returned;
///   * if the left operand is a zero splat and the right operand's type equals
///     the result type, the right operand is returned;
///   * if both operands are constant DenseElementsAttr, the result of
///     `binaryFolder` evaluated against the result type is returned;
///   * otherwise nothing is folded.
OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
  auto lhsTy = getInput1().getType().dyn_cast<RankedTensorType>();
  auto rhsTy = getInput2().getType().dyn_cast<RankedTensorType>();
  auto resultTy = getType().dyn_cast<RankedTensorType>();
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
  auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();

  // Forwarding an operand is only valid when its type already matches the
  // result type; otherwise the add also changes the type and must stay.
  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();
  if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
    return getInput2();

  // Constant folding needs both operands to be constant.
  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
                                                            resultTy);
}
530
427
531
428
OpFoldResult DivOp::fold (ArrayRef<Attribute> operands) {
@@ -603,50 +500,26 @@ OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
603
500
auto resultTy = getType ().dyn_cast <RankedTensorType>();
604
501
if (!lhsTy || !rhsTy || !resultTy)
605
502
return {};
606
- if (lhsTy != rhsTy)
607
- return {};
608
503
609
504
auto resultETy = resultTy.getElementType ();
610
505
auto lhsAttr = operands[0 ].dyn_cast_or_null <DenseElementsAttr>();
611
506
auto rhsAttr = operands[1 ].dyn_cast_or_null <DenseElementsAttr>();
612
507
613
- if (lhsAttr && lhsAttr. isSplat () && resultETy.isa <FloatType >()) {
614
- auto val = lhsAttr. getSplatValue <APFloat>();
615
- if (val. isZero ( ))
508
+ const int64_t shift = resultETy.isa <IntegerType >() ? getShift () : 0 ;
509
+ if (rhsTy == resultTy) {
510
+ if (isSplatZero (resultETy, lhsAttr ))
616
511
return lhsAttr;
617
- if (val. isExactlyValue ( 1.0 ))
512
+ if (isSplatOne (resultETy, lhsAttr, shift ))
618
513
return rhs;
619
514
}
620
-
621
- if (rhsAttr && rhsAttr.isSplat () && resultETy.isa <FloatType>()) {
622
- auto val = rhsAttr.getSplatValue <APFloat>();
623
- if (val.isZero ())
624
- return rhsAttr;
625
- if (val.isExactlyValue (1.0 ))
626
- return lhs;
627
- }
628
-
629
- if (lhsAttr && lhsAttr.isSplat () && resultETy.isa <IntegerType>()) {
630
- auto val = lhsAttr.getSplatValue <APInt>();
631
- if (val.isZero ())
632
- return lhsAttr;
633
- const int64_t shift = getShift ();
634
- const int64_t shifted = 1LL << shift;
635
- if (val.getSExtValue () == shifted)
636
- return rhs;
637
- }
638
-
639
- if (rhsAttr && rhsAttr.isSplat () && resultETy.isa <IntegerType>()) {
640
- auto val = rhsAttr.getSplatValue <APInt>();
641
- const int64_t shift = getShift ();
642
- const int64_t shifted = 1LL << shift;
643
- if (val.isZero ())
515
+ if (lhsTy == resultTy) {
516
+ if (isSplatZero (resultETy, rhsAttr))
644
517
return rhsAttr;
645
- if (val. getSExtValue () == shifted )
518
+ if (isSplatOne (resultETy, rhsAttr, shift) )
646
519
return lhs;
647
520
}
648
521
649
- return mulBinaryFolder (lhsAttr, rhsAttr, lhsTy , getShift ());
522
+ return mulBinaryFolder (lhsAttr, rhsAttr, resultTy , getShift ());
650
523
}
651
524
652
525
OpFoldResult SubOp::fold (ArrayRef<Attribute> operands) {
@@ -655,28 +528,18 @@ OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
655
528
auto resultTy = getType ().dyn_cast <RankedTensorType>();
656
529
if (!lhsTy || !rhsTy || !resultTy)
657
530
return {};
658
- if (lhsTy != rhsTy)
659
- return {};
660
531
661
532
auto resultETy = resultTy.getElementType ();
662
533
auto lhsAttr = operands[0 ].dyn_cast_or_null <DenseElementsAttr>();
663
534
auto rhsAttr = operands[1 ].dyn_cast_or_null <DenseElementsAttr>();
664
-
665
- if (rhsAttr && rhsAttr.isSplat () && resultETy.isa <FloatType>()) {
666
- if (rhsAttr.getSplatValue <APFloat>().isZero ())
667
- return getInput1 ();
668
- }
669
-
670
- if (rhsAttr && rhsAttr.isSplat () && resultETy.isa <IntegerType>()) {
671
- if (rhsAttr.getSplatValue <APInt>().isZero ())
672
- return getInput1 ();
673
- }
535
+ if (lhsTy == resultTy && isSplatZero (resultETy, rhsAttr))
536
+ return getInput1 ();
674
537
675
538
if (!lhsAttr || !rhsAttr)
676
539
return {};
677
540
678
541
return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
679
- lhsTy );
542
+ resultTy );
680
543
}
681
544
682
545
namespace {
@@ -917,7 +780,7 @@ OpFoldResult RsqrtOp::fold(FoldAdaptor adaptor) {
917
780
auto operand = adaptor.getInput1().dyn_cast_or_null<ElementsAttr>();
918
781
if (!operand)
919
782
return {};
920
-
783
+
921
784
if (!inputTy.getElementType().isF32())
922
785
return {};
923
786
@@ -947,7 +810,7 @@ OpFoldResult PowOp::fold(FoldAdaptor adaptor) {
947
810
auto operand2 = adaptor.getInput2().dyn_cast_or_null<ElementsAttr>();
948
811
if (!operand2)
949
812
return {};
950
-
813
+
951
814
if (!operand1.getElementType().isF32() || !operand2.getElementType().isF32())
952
815
return {};
953
816
@@ -961,7 +824,7 @@ OpFoldResult PowOp::fold(FoldAdaptor adaptor) {
961
824
962
825
OpFoldResult ReciprocalOp::fold(FoldAdaptor adaptor) {
963
826
auto src = adaptor.getInput1().dyn_cast_or_null<mlir::DenseElementsAttr>();
964
-
827
+
965
828
if (!src)
966
829
return nullptr;
967
830
@@ -989,7 +852,6 @@ OpFoldResult ReverseOp::fold(ArrayRef<Attribute> operands) {
989
852
return {};
990
853
}
991
854
992
-
993
855
OpFoldResult SliceOp::fold (ArrayRef<Attribute> operands) {
994
856
auto inputTy = getInput ().getType ().dyn_cast <RankedTensorType>();
995
857
auto outputTy = getType ().dyn_cast <RankedTensorType>();
0 commit comments