Skip to content

Commit d7d3026

Browse files
committed
[mlir][spirv] Add folding for [I|Logical][Not]Equal
Add missing constant propogation folder for [I|Logical][N]Eq Implement additional folding when lhs == rhs for all ops. As well as, fix test cases in logical-ops-to-llvm that failed due to introduced folding. This helps for readability of lowered code into SPIR-V. Part of work for #70704
1 parent ab41ea4 commit d7d3026

File tree

4 files changed

+276
-11
lines changed

4 files changed

+276
-11
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,8 @@ def SPIRV_IEqualOp : SPIRV_LogicalBinaryOp<"IEqual",
369369
%5 = spirv.IEqual %2, %3 : vector<4xi32>
370370
```
371371
}];
372+
373+
let hasFolder = 1;
372374
}
373375

374376
// -----
@@ -395,6 +397,8 @@ def SPIRV_INotEqualOp : SPIRV_LogicalBinaryOp<"INotEqual",
395397

396398
```
397399
}];
400+
401+
let hasFolder = 1;
398402
}
399403

400404
// -----
@@ -501,6 +505,8 @@ def SPIRV_LogicalEqualOp : SPIRV_LogicalBinaryOp<"LogicalEqual",
501505
%2 = spirv.LogicalEqual %0, %1 : vector<4xi1>
502506
```
503507
}];
508+
509+
let hasFolder = 1;
504510
}
505511

506512
// -----
@@ -557,7 +563,8 @@ def SPIRV_LogicalNotEqualOp : SPIRV_LogicalBinaryOp<"LogicalNotEqual",
557563
%2 = spirv.LogicalNotEqual %0, %1 : vector<4xi1>
558564
```
559565
}];
560-
let hasFolder = true;
566+
567+
let hasFolder = 1;
561568
}
562569

563570
// -----

mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp

Lines changed: 95 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -662,19 +662,62 @@ OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
662662
return Attribute();
663663
}
664664

665+
//===----------------------------------------------------------------------===//
666+
// spirv.LogicalEqualOp
667+
//===----------------------------------------------------------------------===//
668+
669+
OpFoldResult
670+
spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
671+
// x == x -> true
672+
if (getOperand1() == getOperand2()) {
673+
auto type = getType();
674+
if (isa<IntegerType>(type)) {
675+
return BoolAttr::get(getContext(), true);
676+
}
677+
if (isa<VectorType>(type)) {
678+
auto vtType = cast<ShapedType>(type);
679+
auto element = BoolAttr::get(getContext(), true);
680+
return DenseElementsAttr::get(vtType, element);
681+
}
682+
}
683+
684+
return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
685+
[](const APInt &a, const APInt &b) {
686+
APInt zero = APInt::getZero(1);
687+
return a == b ? (zero + 1) : zero;
688+
});
689+
}
690+
665691
//===----------------------------------------------------------------------===//
666692
// spirv.LogicalNotEqualOp
667693
//===----------------------------------------------------------------------===//
668694

669695
OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
670696
if (std::optional<bool> rhs =
671697
getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
672-
// x && false = x
698+
// x != false -> x
673699
if (!rhs.value())
674700
return getOperand1();
675701
}
676702

677-
return Attribute();
703+
// x == x -> false
704+
if (getOperand1() == getOperand2()) {
705+
auto type = getType();
706+
if (isa<IntegerType>(type)) {
707+
return BoolAttr::get(getContext(), false);
708+
}
709+
if (isa<VectorType>(type)) {
710+
auto vtType = cast<ShapedType>(type);
711+
auto element = BoolAttr::get(getContext(), false);
712+
return DenseElementsAttr::get(vtType, element);
713+
}
714+
}
715+
716+
return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
717+
[](const APInt &a, const APInt &b) {
718+
APInt zero = APInt::getZero(1);
719+
return a == b ? zero : (zero + 1);
720+
});
678721
}
679722

680723
//===----------------------------------------------------------------------===//
@@ -709,6 +752,56 @@ OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
709752
return Attribute();
710753
}
711754

755+
//===----------------------------------------------------------------------===//
756+
// spirv.IEqualOp
757+
//===----------------------------------------------------------------------===//
758+
759+
OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
760+
// x == x -> true
761+
if (getOperand1() == getOperand2()) {
762+
auto type = getType();
763+
if (isa<IntegerType>(type)) {
764+
return BoolAttr::get(getContext(), true);
765+
}
766+
if (isa<VectorType>(type)) {
767+
auto vtType = cast<ShapedType>(type);
768+
auto element = BoolAttr::get(getContext(), true);
769+
return DenseElementsAttr::get(vtType, element);
770+
}
771+
}
772+
773+
return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), getType(),
774+
[](const APInt &a, const APInt &b) {
775+
APInt zero = APInt::getZero(1);
776+
return a == b ? (zero + 1) : zero;
777+
});
778+
}
779+
780+
//===----------------------------------------------------------------------===//
781+
// spirv.INotEqualOp
782+
//===----------------------------------------------------------------------===//
783+
784+
OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
785+
// x == x -> false
786+
if (getOperand1() == getOperand2()) {
787+
auto type = getType();
788+
if (isa<IntegerType>(type)) {
789+
return BoolAttr::get(getContext(), false);
790+
}
791+
if (isa<VectorType>(type)) {
792+
auto vtType = cast<ShapedType>(type);
793+
auto element = BoolAttr::get(getContext(), false);
794+
return DenseElementsAttr::get(vtType, element);
795+
}
796+
}
797+
798+
return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), getType(),
799+
[](const APInt &a, const APInt &b) {
800+
APInt zero = APInt::getZero(1);
801+
return a == b ? zero : (zero + 1);
802+
});
803+
}
804+
712805
//===----------------------------------------------------------------------===//
713806
// spirv.ShiftLeftLogical
714807
//===----------------------------------------------------------------------===//

mlir/test/Conversion/SPIRVToLLVM/logical-ops-to-llvm.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77
// CHECK-LABEL: @logical_equal_scalar
88
spirv.func @logical_equal_scalar(%arg0: i1, %arg1: i1) "None" {
99
// CHECK: llvm.icmp "eq" %{{.*}}, %{{.*}} : i1
10-
%0 = spirv.LogicalEqual %arg0, %arg0 : i1
10+
%0 = spirv.LogicalEqual %arg0, %arg1 : i1
1111
spirv.Return
1212
}
1313

1414
// CHECK-LABEL: @logical_equal_vector
1515
spirv.func @logical_equal_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
1616
// CHECK: llvm.icmp "eq" %{{.*}}, %{{.*}} : vector<4xi1>
17-
%0 = spirv.LogicalEqual %arg0, %arg0 : vector<4xi1>
17+
%0 = spirv.LogicalEqual %arg0, %arg1 : vector<4xi1>
1818
spirv.Return
1919
}
2020

@@ -25,14 +25,14 @@ spirv.func @logical_equal_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None
2525
// CHECK-LABEL: @logical_not_equal_scalar
2626
spirv.func @logical_not_equal_scalar(%arg0: i1, %arg1: i1) "None" {
2727
// CHECK: llvm.icmp "ne" %{{.*}}, %{{.*}} : i1
28-
%0 = spirv.LogicalNotEqual %arg0, %arg0 : i1
28+
%0 = spirv.LogicalNotEqual %arg0, %arg1 : i1
2929
spirv.Return
3030
}
3131

3232
// CHECK-LABEL: @logical_not_equal_vector
3333
spirv.func @logical_not_equal_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
3434
// CHECK: llvm.icmp "ne" %{{.*}}, %{{.*}} : vector<4xi1>
35-
%0 = spirv.LogicalNotEqual %arg0, %arg0 : vector<4xi1>
35+
%0 = spirv.LogicalNotEqual %arg0, %arg1 : vector<4xi1>
3636
spirv.Return
3737
}
3838

@@ -63,14 +63,14 @@ spirv.func @logical_not_vector(%arg0: vector<4xi1>) "None" {
6363
// CHECK-LABEL: @logical_and_scalar
6464
spirv.func @logical_and_scalar(%arg0: i1, %arg1: i1) "None" {
6565
// CHECK: llvm.and %{{.*}}, %{{.*}} : i1
66-
%0 = spirv.LogicalAnd %arg0, %arg0 : i1
66+
%0 = spirv.LogicalAnd %arg0, %arg1 : i1
6767
spirv.Return
6868
}
6969

7070
// CHECK-LABEL: @logical_and_vector
7171
spirv.func @logical_and_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
7272
// CHECK: llvm.and %{{.*}}, %{{.*}} : vector<4xi1>
73-
%0 = spirv.LogicalAnd %arg0, %arg0 : vector<4xi1>
73+
%0 = spirv.LogicalAnd %arg0, %arg1 : vector<4xi1>
7474
spirv.Return
7575
}
7676

@@ -81,13 +81,13 @@ spirv.func @logical_and_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None"
8181
// CHECK-LABEL: @logical_or_scalar
8282
spirv.func @logical_or_scalar(%arg0: i1, %arg1: i1) "None" {
8383
// CHECK: llvm.or %{{.*}}, %{{.*}} : i1
84-
%0 = spirv.LogicalOr %arg0, %arg0 : i1
84+
%0 = spirv.LogicalOr %arg0, %arg1 : i1
8585
spirv.Return
8686
}
8787

8888
// CHECK-LABEL: @logical_or_vector
8989
spirv.func @logical_or_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
9090
// CHECK: llvm.or %{{.*}}, %{{.*}} : vector<4xi1>
91-
%0 = spirv.LogicalOr %arg0, %arg0 : vector<4xi1>
91+
%0 = spirv.LogicalOr %arg0, %arg1 : vector<4xi1>
9292
spirv.Return
9393
}

0 commit comments

Comments
 (0)