@@ -662,19 +662,62 @@ OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
662
662
return Attribute ();
663
663
}
664
664
665
+ // ===----------------------------------------------------------------------===//
666
+ // spirv.LogicalEqualOp
667
+ // ===----------------------------------------------------------------------===//
668
+
669
+ OpFoldResult
670
+ spirv::LogicalEqualOp::fold (spirv::LogicalEqualOp::FoldAdaptor adaptor) {
671
+ // x == x -> true
672
+ if (getOperand1 () == getOperand2 ()) {
673
+ auto type = getType ();
674
+ if (isa<IntegerType>(type)) {
675
+ return BoolAttr::get (getContext (), true );
676
+ }
677
+ if (isa<VectorType>(type)) {
678
+ auto vtType = cast<ShapedType>(type);
679
+ auto element = BoolAttr::get (getContext (), true );
680
+ return DenseElementsAttr::get (vtType, element);
681
+ }
682
+ }
683
+
684
+ return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands (),
685
+ [](const APInt &a, const APInt &b) {
686
+ APInt zero = APInt::getZero (1 );
687
+ return a == b ? (zero + 1 ) : zero;
688
+ });
689
+ }
690
+
665
691
// ===----------------------------------------------------------------------===//
666
692
// spirv.LogicalNotEqualOp
667
693
// ===----------------------------------------------------------------------===//
668
694
669
695
OpFoldResult spirv::LogicalNotEqualOp::fold (FoldAdaptor adaptor) {
670
696
if (std::optional<bool > rhs =
671
697
getScalarOrSplatBoolAttr (adaptor.getOperand2 ())) {
672
- // x && false = x
698
+ // x != false -> x
673
699
if (!rhs.value ())
674
700
return getOperand1 ();
675
701
}
676
702
677
- return Attribute ();
703
+ // x == x -> false
704
+ if (getOperand1 () == getOperand2 ()) {
705
+ auto type = getType ();
706
+ if (isa<IntegerType>(type)) {
707
+ return BoolAttr::get (getContext (), false );
708
+ }
709
+ if (isa<VectorType>(type)) {
710
+ auto vtType = cast<ShapedType>(type);
711
+ auto element = BoolAttr::get (getContext (), false );
712
+ return DenseElementsAttr::get (vtType, element);
713
+ }
714
+ }
715
+
716
+ return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands (),
717
+ [](const APInt &a, const APInt &b) {
718
+ APInt zero = APInt::getZero (1 );
719
+ return a == b ? zero : (zero + 1 );
720
+ });
678
721
}
679
722
680
723
// ===----------------------------------------------------------------------===//
@@ -709,6 +752,56 @@ OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
709
752
return Attribute ();
710
753
}
711
754
755
+ // ===----------------------------------------------------------------------===//
756
+ // spirv.IEqualOp
757
+ // ===----------------------------------------------------------------------===//
758
+
759
+ OpFoldResult spirv::IEqualOp::fold (spirv::IEqualOp::FoldAdaptor adaptor) {
760
+ // x == x -> true
761
+ if (getOperand1 () == getOperand2 ()) {
762
+ auto type = getType ();
763
+ if (isa<IntegerType>(type)) {
764
+ return BoolAttr::get (getContext (), true );
765
+ }
766
+ if (isa<VectorType>(type)) {
767
+ auto vtType = cast<ShapedType>(type);
768
+ auto element = BoolAttr::get (getContext (), true );
769
+ return DenseElementsAttr::get (vtType, element);
770
+ }
771
+ }
772
+
773
+ return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands (), getType (),
774
+ [](const APInt &a, const APInt &b) {
775
+ APInt zero = APInt::getZero (1 );
776
+ return a == b ? (zero + 1 ) : zero;
777
+ });
778
+ }
779
+
780
+ // ===----------------------------------------------------------------------===//
781
+ // spirv.INotEqualOp
782
+ // ===----------------------------------------------------------------------===//
783
+
784
+ OpFoldResult spirv::INotEqualOp::fold (spirv::INotEqualOp::FoldAdaptor adaptor) {
785
+ // x == x -> false
786
+ if (getOperand1 () == getOperand2 ()) {
787
+ auto type = getType ();
788
+ if (isa<IntegerType>(type)) {
789
+ return BoolAttr::get (getContext (), false );
790
+ }
791
+ if (isa<VectorType>(type)) {
792
+ auto vtType = cast<ShapedType>(type);
793
+ auto element = BoolAttr::get (getContext (), false );
794
+ return DenseElementsAttr::get (vtType, element);
795
+ }
796
+ }
797
+
798
+ return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands (), getType (),
799
+ [](const APInt &a, const APInt &b) {
800
+ APInt zero = APInt::getZero (1 );
801
+ return a == b ? zero : (zero + 1 );
802
+ });
803
+ }
804
+
712
805
// ===----------------------------------------------------------------------===//
713
806
// spirv.ShiftLeftLogical
714
807
// ===----------------------------------------------------------------------===//
0 commit comments