Skip to content

Commit 88151dd

Browse files
authored
[mlir][spirv] Add folding for SNegate, [Logical]Not (#74992)
Add missing constant propogation folder for SNegate, [Logical]Not. Implement additional folding when !(!x) for all ops. This helps for readability of lowered code into SPIR-V. Part of work for #70704
1 parent 537b2aa commit 88151dd

File tree

5 files changed

+188
-0
lines changed

5 files changed

+188
-0
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,8 @@ def SPIRV_SNegateOp : SPIRV_ArithmeticUnaryOp<"SNegate",
582582
%3 = spirv.SNegate %2 : vector<4xi32>
583583
```
584584
}];
585+
586+
let hasFolder = 1;
585587
}
586588

587589
// -----

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,8 @@ def SPIRV_NotOp : SPIRV_BitUnaryOp<"Not", [UsableInSpecConstantOp]> {
462462
%3 = spirv.Not %1 : vector<4xi32>
463463
```
464464
}];
465+
466+
let hasFolder = 1;
465467
}
466468

467469
#endif // MLIR_DIALECT_SPIRV_IR_BIT_OPS

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,7 @@ def SPIRV_LogicalNotOp : SPIRV_LogicalUnaryOp<"LogicalNot",
534534
}];
535535

536536
let hasCanonicalizer = 1;
537+
let hasFolder = 1;
537538
}
538539

539540
// -----

mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,45 @@ OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) {
643643
return div0 ? Attribute() : res;
644644
}
645645

646+
//===----------------------------------------------------------------------===//
647+
// spirv.SNegate
648+
//===----------------------------------------------------------------------===//
649+
650+
OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) {
651+
// -(-x) = 0 - (0 - x) = x
652+
auto op = getOperand();
653+
if (auto negateOp = op.getDefiningOp<spirv::SNegateOp>())
654+
return negateOp->getOperand(0);
655+
656+
// According to the SPIR-V spec:
657+
//
658+
// Signed-integer subtract of Operand from zero.
659+
return constFoldUnaryOp<IntegerAttr>(
660+
adaptor.getOperands(), [](const APInt &a) {
661+
APInt zero = APInt::getZero(a.getBitWidth());
662+
return zero - a;
663+
});
664+
}
665+
666+
//===----------------------------------------------------------------------===//
667+
// spirv.NotOp
668+
//===----------------------------------------------------------------------===//
669+
670+
OpFoldResult spirv::NotOp::fold(spirv::NotOp::FoldAdaptor adaptor) {
671+
// !(!x) = x
672+
auto op = getOperand();
673+
if (auto notOp = op.getDefiningOp<spirv::NotOp>())
674+
return notOp->getOperand(0);
675+
676+
// According to the SPIR-V spec:
677+
//
678+
// Complement the bits of Operand.
679+
return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(), [&](APInt a) {
680+
a.flipAllBits();
681+
return a;
682+
});
683+
}
684+
646685
//===----------------------------------------------------------------------===//
647686
// spirv.LogicalAnd
648687
//===----------------------------------------------------------------------===//
@@ -714,6 +753,22 @@ OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
714753
// spirv.LogicalNot
715754
//===----------------------------------------------------------------------===//
716755

756+
OpFoldResult spirv::LogicalNotOp::fold(FoldAdaptor adaptor) {
757+
// !(!x) = x
758+
auto op = getOperand();
759+
if (auto notOp = op.getDefiningOp<spirv::LogicalNotOp>())
760+
return notOp->getOperand(0);
761+
762+
// According to the SPIR-V spec:
763+
//
764+
// Complement the bits of Operand.
765+
return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(),
766+
[](const APInt &a) {
767+
APInt zero = APInt::getZero(1);
768+
return a == 1 ? zero : (zero + 1);
769+
});
770+
}
771+
717772
void spirv::LogicalNotOp::getCanonicalizationPatterns(
718773
RewritePatternSet &results, MLIRContext *context) {
719774
results

mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,6 +1006,102 @@ func.func @umod_fail_fold(%arg0: i32) -> (i32, i32) {
10061006

10071007
// -----
10081008

1009+
//===----------------------------------------------------------------------===//
1010+
// spirv.SNegate
1011+
//===----------------------------------------------------------------------===//
1012+
1013+
// CHECK-LABEL: @snegate_twice
1014+
// CHECK-SAME: (%[[ARG:.*]]: i32)
1015+
func.func @snegate_twice(%arg0 : i32) -> i32 {
1016+
%0 = spirv.SNegate %arg0 : i32
1017+
%1 = spirv.SNegate %0 : i32
1018+
1019+
// CHECK: return %[[ARG]] : i32
1020+
return %1 : i32
1021+
}
1022+
1023+
// CHECK-LABEL: @snegate_min
1024+
func.func @snegate_min() -> (i8, i8) {
1025+
// CHECK: %[[MIN:.*]] = spirv.Constant -128 : i8
1026+
%cmin = spirv.Constant -128 : i8
1027+
1028+
%0 = spirv.SNegate %cmin : i8
1029+
%1 = spirv.SNegate %0 : i8
1030+
1031+
// CHECK: return %[[MIN]], %[[MIN]]
1032+
return %0, %1 : i8, i8
1033+
}
1034+
1035+
// CHECK-LABEL: @const_fold_scalar_snegate
1036+
func.func @const_fold_scalar_snegate() -> (i32, i32, i32) {
1037+
%c0 = spirv.Constant 0 : i32
1038+
%c3 = spirv.Constant 3 : i32
1039+
%cn3 = spirv.Constant -3 : i32
1040+
1041+
// CHECK-DAG: %[[THREE:.*]] = spirv.Constant 3 : i32
1042+
// CHECK-DAG: %[[NTHREE:.*]] = spirv.Constant -3 : i32
1043+
// CHECK-DAG: %[[ZERO:.*]] = spirv.Constant 0 : i32
1044+
%0 = spirv.SNegate %c0 : i32
1045+
%1 = spirv.SNegate %c3 : i32
1046+
%2 = spirv.SNegate %cn3 : i32
1047+
1048+
// CHECK: return %[[ZERO]], %[[NTHREE]], %[[THREE]]
1049+
return %0, %1, %2 : i32, i32, i32
1050+
}
1051+
1052+
// CHECK-LABEL: @const_fold_vector_snegate
1053+
func.func @const_fold_vector_snegate() -> vector<3xi32> {
1054+
// CHECK: spirv.Constant dense<[0, 3, -3]>
1055+
%cv = spirv.Constant dense<[0, -3, 3]> : vector<3xi32>
1056+
%0 = spirv.SNegate %cv : vector<3xi32>
1057+
return %0 : vector<3xi32>
1058+
}
1059+
1060+
// -----
1061+
1062+
//===----------------------------------------------------------------------===//
1063+
// spirv.Not
1064+
//===----------------------------------------------------------------------===//
1065+
1066+
// CHECK-LABEL: @not_twice
1067+
// CHECK-SAME: (%[[ARG:.*]]: i32)
1068+
func.func @not_twice(%arg0 : i32) -> i32 {
1069+
%0 = spirv.Not %arg0 : i32
1070+
%1 = spirv.Not %0 : i32
1071+
1072+
// CHECK: return %[[ARG]] : i32
1073+
return %1 : i32
1074+
}
1075+
1076+
// CHECK-LABEL: @const_fold_scalar_not
1077+
func.func @const_fold_scalar_not() -> (i32, i32, i32) {
1078+
%c0 = spirv.Constant 0 : i32
1079+
%c3 = spirv.Constant 3 : i32
1080+
%cn3 = spirv.Constant -3 : i32
1081+
1082+
// CHECK-DAG: %[[TWO:.*]] = spirv.Constant 2 : i32
1083+
// CHECK-DAG: %[[NFOUR:.*]] = spirv.Constant -4 : i32
1084+
// CHECK-DAG: %[[NONE:.*]] = spirv.Constant -1 : i32
1085+
%0 = spirv.Not %c0 : i32
1086+
%1 = spirv.Not %c3 : i32
1087+
%2 = spirv.Not %cn3 : i32
1088+
1089+
// CHECK: return %[[NONE]], %[[NFOUR]], %[[TWO]]
1090+
return %0, %1, %2 : i32, i32, i32
1091+
}
1092+
1093+
// CHECK-LABEL: @const_fold_vector_not
1094+
func.func @const_fold_vector_not() -> vector<3xi32> {
1095+
%cv = spirv.Constant dense<[-1, -4, 2]> : vector<3xi32>
1096+
1097+
// CHECK: spirv.Constant dense<[0, 3, -3]>
1098+
%0 = spirv.Not %cv : vector<3xi32>
1099+
1100+
return %0 : vector<3xi32>
1101+
}
1102+
1103+
// -----
1104+
10091105
//===----------------------------------------------------------------------===//
10101106
// spirv.LogicalAnd
10111107
//===----------------------------------------------------------------------===//
@@ -1040,6 +1136,38 @@ func.func @convert_logical_and_true_false_vector(%arg: vector<3xi1>) -> (vector<
10401136
// spirv.LogicalNot
10411137
//===----------------------------------------------------------------------===//
10421138

1139+
// CHECK-LABEL: @logical_not_twice
1140+
// CHECK-SAME: (%[[ARG:.*]]: i1)
1141+
func.func @logical_not_twice(%arg0 : i1) -> i1 {
1142+
%0 = spirv.LogicalNot %arg0 : i1
1143+
%1 = spirv.LogicalNot %0 : i1
1144+
1145+
// CHECK: return %[[ARG]] : i1
1146+
return %1 : i1
1147+
}
1148+
1149+
// CHECK-LABEL: @const_fold_scalar_logical_not
1150+
func.func @const_fold_scalar_logical_not() -> i1 {
1151+
%true = spirv.Constant true
1152+
1153+
// CHECK: spirv.Constant false
1154+
%0 = spirv.LogicalNot %true : i1
1155+
1156+
return %0 : i1
1157+
}
1158+
1159+
// CHECK-LABEL: @const_fold_vector_logical_not
1160+
func.func @const_fold_vector_logical_not() -> vector<2xi1> {
1161+
%cv = spirv.Constant dense<[true, false]> : vector<2xi1>
1162+
1163+
// CHECK: spirv.Constant dense<[false, true]>
1164+
%0 = spirv.LogicalNot %cv : vector<2xi1>
1165+
1166+
return %0 : vector<2xi1>
1167+
}
1168+
1169+
// -----
1170+
10431171
func.func @convert_logical_not_to_not_equal(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi1> {
10441172
// CHECK: %[[RESULT:.*]] = spirv.INotEqual {{%.*}}, {{%.*}} : vector<3xi64>
10451173
// CHECK-NEXT: spirv.ReturnValue %[[RESULT]] : vector<3xi1>

0 commit comments

Comments
 (0)