-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][spirv] Add folding for SNegate, [Logical]Not #74992
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Add missing constant propogation folder for SNegate, [Logical]Not. Implement additional folding when !(!x) for all ops. This helps for readability of lowered code into SPIR-V. Part of work for llvm#70704
174035d
to
e2e231d
Compare
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-spirv Author: Finn Plummer (inbelic) ChangesAdd missing constant propogation folder for SNegate, [Logical]Not. Implement additional folding when !(!x) for all ops. This helps for readability of lowered code into SPIR-V. Part of work for #70704 Full diff: https://github.com/llvm/llvm-project/pull/74992.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 51124e141c6d46..22d5afcd773817 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -582,6 +582,8 @@ def SPIRV_SNegateOp : SPIRV_ArithmeticUnaryOp<"SNegate",
%3 = spirv.SNegate %2 : vector<4xi32>
```
}];
+
+ let hasFolder = 1;
}
// -----
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
index b460c8e68aa0c6..38639a175ab4db 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
@@ -462,6 +462,8 @@ def SPIRV_NotOp : SPIRV_BitUnaryOp<"Not", [UsableInSpecConstantOp]> {
%3 = spirv.Not %1 : vector<4xi32>
```
}];
+
+ let hasFolder = 1;
}
#endif // MLIR_DIALECT_SPIRV_IR_BIT_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
index 47887ffb474f00..260d24b5502577 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
@@ -528,6 +528,7 @@ def SPIRV_LogicalNotOp : SPIRV_LogicalUnaryOp<"LogicalNot",
}];
let hasCanonicalizer = 1;
+ let hasFolder = 1;
}
// -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 9de1707dfca465..fe334d50b6faaa 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -643,6 +643,45 @@ OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) {
return div0 ? Attribute() : res;
}
+//===----------------------------------------------------------------------===//
+// spirv.SNegate
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) {
+ // -(-x) = 0 - (0 - x) = x
+ auto op = getOperand();
+ if (auto negateOp = op.getDefiningOp<spirv::SNegateOp>())
+ return negateOp->getOperand(0);
+
+ // According to the SPIR-V spec:
+ //
+ // Signed-integer subtract of Operand from zero.
+ return constFoldUnaryOp<IntegerAttr>(
+ adaptor.getOperands(), [](const APInt &a) {
+ APInt zero = APInt::getZero(a.getBitWidth());
+ return zero - a;
+ });
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.NotOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::NotOp::fold(spirv::NotOp::FoldAdaptor adaptor) {
+ // !(!x) = x
+ auto op = getOperand();
+ if (auto notOp = op.getDefiningOp<spirv::NotOp>())
+ return notOp->getOperand(0);
+
+ // According to the SPIR-V spec:
+ //
+ // Complement the bits of Operand.
+ return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(), [&](APInt a) {
+ a.flipAllBits();
+ return a;
+ });
+}
+
//===----------------------------------------------------------------------===//
// spirv.LogicalAnd
//===----------------------------------------------------------------------===//
@@ -681,6 +720,22 @@ OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
// spirv.LogicalNot
//===----------------------------------------------------------------------===//
+OpFoldResult spirv::LogicalNotOp::fold(FoldAdaptor adaptor) {
+ // !(!x) = x
+ auto op = getOperand();
+ if (auto notOp = op.getDefiningOp<spirv::LogicalNotOp>())
+ return notOp->getOperand(0);
+
+ // According to the SPIR-V spec:
+ //
+ // Complement the bits of Operand.
+ return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(),
+ [](const APInt &a) {
+ APInt zero = APInt::getZero(1);
+ return a == 1 ? zero : (zero + 1);
+ });
+}
+
void spirv::LogicalNotOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 29bea91ce461d9..7da2cf5be4e007 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -1006,6 +1006,90 @@ func.func @umod_fail_fold(%arg0: i32) -> (i32, i32) {
// -----
+//===----------------------------------------------------------------------===//
+// spirv.SNegate
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @snegate_twice
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @snegate_twice(%arg0 : i32) -> i32 {
+ %0 = spirv.SNegate %arg0 : i32
+ %1 = spirv.SNegate %0 : i32
+
+ // CHECK: return %[[ARG]] : i32
+ return %1 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_snegate
+func.func @const_fold_scalar_snegate() -> (i32, i32, i32) {
+ %c0 = spirv.Constant 0 : i32
+ %c3 = spirv.Constant 3 : i32
+ %cn3 = spirv.Constant -3 : i32
+
+ // CHECK-DAG: %[[THREE:.*]] = spirv.Constant 3 : i32
+ // CHECK-DAG: %[[NTHREE:.*]] = spirv.Constant -3 : i32
+ // CHECK-DAG: %[[ZERO:.*]] = spirv.Constant 0 : i32
+ %0 = spirv.SNegate %c0 : i32
+ %1 = spirv.SNegate %c3 : i32
+ %2 = spirv.SNegate %cn3 : i32
+
+ // CHECK: return %[[ZERO]], %[[NTHREE]], %[[THREE]]
+ return %0, %1, %2 : i32, i32, i32
+}
+
+// CHECK-LABEL: @const_fold_vector_snegate
+func.func @const_fold_vector_snegate() -> vector<3xi32> {
+ // CHECK: spirv.Constant dense<[0, 3, -3]>
+ %cv = spirv.Constant dense<[0, -3, 3]> : vector<3xi32>
+ %0 = spirv.SNegate %cv : vector<3xi32>
+ return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.Not
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @not_twice
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @not_twice(%arg0 : i32) -> i32 {
+ %0 = spirv.Not %arg0 : i32
+ %1 = spirv.Not %0 : i32
+
+ // CHECK: return %[[ARG]] : i32
+ return %1 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_not
+func.func @const_fold_scalar_not() -> (i32, i32, i32) {
+ %c0 = spirv.Constant 0 : i32
+ %c3 = spirv.Constant 3 : i32
+ %cn3 = spirv.Constant -3 : i32
+
+ // CHECK-DAG: %[[TWO:.*]] = spirv.Constant 2 : i32
+ // CHECK-DAG: %[[NFOUR:.*]] = spirv.Constant -4 : i32
+ // CHECK-DAG: %[[NONE:.*]] = spirv.Constant -1 : i32
+ %0 = spirv.Not %c0 : i32
+ %1 = spirv.Not %c3 : i32
+ %2 = spirv.Not %cn3 : i32
+
+ // CHECK: return %[[NONE]], %[[NFOUR]], %[[TWO]]
+ return %0, %1, %2 : i32, i32, i32
+}
+
+// CHECK-LABEL: @const_fold_vector_not
+func.func @const_fold_vector_not() -> vector<3xi32> {
+ %cv = spirv.Constant dense<[-1, -4, 2]> : vector<3xi32>
+
+ // CHECK: spirv.Constant dense<[0, 3, -3]>
+ %0 = spirv.Not %cv : vector<3xi32>
+
+ return %0 : vector<3xi32>
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.LogicalAnd
//===----------------------------------------------------------------------===//
@@ -1040,6 +1124,38 @@ func.func @convert_logical_and_true_false_vector(%arg: vector<3xi1>) -> (vector<
// spirv.LogicalNot
//===----------------------------------------------------------------------===//
+// CHECK-LABEL: @logical_not_twice
+// CHECK-SAME: (%[[ARG:.*]]: i1)
+func.func @logical_not_twice(%arg0 : i1) -> i1 {
+ %0 = spirv.LogicalNot %arg0 : i1
+ %1 = spirv.LogicalNot %0 : i1
+
+ // CHECK: return %[[ARG]] : i1
+ return %1 : i1
+}
+
+// CHECK-LABEL: @const_fold_scalar_logical_not
+func.func @const_fold_scalar_logical_not() -> i1 {
+ %true = spirv.Constant true
+
+ // CHECK: spirv.Constant false
+ %0 = spirv.LogicalNot %true : i1
+
+ return %0 : i1
+}
+
+// CHECK-LABEL: @const_fold_vector_logical_not
+func.func @const_fold_vector_logical_not() -> vector<2xi1> {
+ %cv = spirv.Constant dense<[true, false]> : vector<2xi1>
+
+ // CHECK: spirv.Constant dense<[false, true]>
+ %0 = spirv.LogicalNot %cv : vector<2xi1>
+
+ return %0 : vector<2xi1>
+}
+
+// -----
+
func.func @convert_logical_not_to_not_equal(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi1> {
// CHECK: %[[RESULT:.*]] = spirv.INotEqual {{%.*}}, {{%.*}} : vector<3xi64>
// CHECK-NEXT: spirv.ReturnValue %[[RESULT]] : vector<3xi1>
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great, I think it would be worth adding one more test for though.
- add testcase to demonstrate SNegate behaviour for INT_MIN
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Add missing constant propogation folder for SNegate, [Logical]Not.
Implement additional folding when !(!x) for all ops.
This helps for readability of lowered code into SPIR-V.
Part of work for #70704