Skip to content

[mlir][spirv] Add folding for SNegate, [Logical]Not #74992

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 21, 2023

Conversation

inbelic
Copy link
Contributor

@inbelic inbelic commented Dec 10, 2023

Add missing constant propogation folder for SNegate, [Logical]Not.

Implement additional folding when !(!x) for all ops.

This helps for readability of lowered code into SPIR-V.

Part of work for #70704

Add missing constant propogation folder for SNegate, [Logical]Not.

Implement additional folding when !(!x) for all ops.

This helps for readability of lowered code into SPIR-V.

Part of work for llvm#70704
@inbelic inbelic force-pushed the inbelic/spirv-folding-negate-ops branch from 174035d to e2e231d Compare December 20, 2023 08:16
@inbelic inbelic marked this pull request as ready for review December 20, 2023 17:03
@llvmbot
Copy link
Member

llvmbot commented Dec 20, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: Finn Plummer (inbelic)

Changes

Add missing constant propogation folder for SNegate, [Logical]Not.

Implement additional folding when !(!x) for all ops.

This helps for readability of lowered code into SPIR-V.

Part of work for #70704


Full diff: https://github.com/llvm/llvm-project/pull/74992.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td (+2)
  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td (+2)
  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td (+1)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp (+55)
  • (modified) mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir (+116)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 51124e141c6d46..22d5afcd773817 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -582,6 +582,8 @@ def SPIRV_SNegateOp : SPIRV_ArithmeticUnaryOp<"SNegate",
     %3 = spirv.SNegate %2 : vector<4xi32>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
index b460c8e68aa0c6..38639a175ab4db 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
@@ -462,6 +462,8 @@ def SPIRV_NotOp : SPIRV_BitUnaryOp<"Not", [UsableInSpecConstantOp]> {
     %3 = spirv.Not %1 : vector<4xi32>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 #endif // MLIR_DIALECT_SPIRV_IR_BIT_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
index 47887ffb474f00..260d24b5502577 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
@@ -528,6 +528,7 @@ def SPIRV_LogicalNotOp : SPIRV_LogicalUnaryOp<"LogicalNot",
   }];
 
   let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 // -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 9de1707dfca465..fe334d50b6faaa 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -643,6 +643,45 @@ OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) {
   return div0 ? Attribute() : res;
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.SNegate
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) {
+  // -(-x) = 0 - (0 - x) = x
+  auto op = getOperand();
+  if (auto negateOp = op.getDefiningOp<spirv::SNegateOp>())
+    return negateOp->getOperand(0);
+
+  // According to the SPIR-V spec:
+  //
+  // Signed-integer subtract of Operand from zero.
+  return constFoldUnaryOp<IntegerAttr>(
+      adaptor.getOperands(), [](const APInt &a) {
+        APInt zero = APInt::getZero(a.getBitWidth());
+        return zero - a;
+      });
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.NotOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::NotOp::fold(spirv::NotOp::FoldAdaptor adaptor) {
+  // !(!x) = x
+  auto op = getOperand();
+  if (auto notOp = op.getDefiningOp<spirv::NotOp>())
+    return notOp->getOperand(0);
+
+  // According to the SPIR-V spec:
+  //
+  // Complement the bits of Operand.
+  return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(), [&](APInt a) {
+    a.flipAllBits();
+    return a;
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.LogicalAnd
 //===----------------------------------------------------------------------===//
@@ -681,6 +720,22 @@ OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
 // spirv.LogicalNot
 //===----------------------------------------------------------------------===//
 
+OpFoldResult spirv::LogicalNotOp::fold(FoldAdaptor adaptor) {
+  // !(!x) = x
+  auto op = getOperand();
+  if (auto notOp = op.getDefiningOp<spirv::LogicalNotOp>())
+    return notOp->getOperand(0);
+
+  // According to the SPIR-V spec:
+  //
+  // Complement the bits of Operand.
+  return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(),
+                                       [](const APInt &a) {
+                                         APInt zero = APInt::getZero(1);
+                                         return a == 1 ? zero : (zero + 1);
+                                       });
+}
+
 void spirv::LogicalNotOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   results
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 29bea91ce461d9..7da2cf5be4e007 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -1006,6 +1006,90 @@ func.func @umod_fail_fold(%arg0: i32) -> (i32, i32) {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.SNegate
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @snegate_twice
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @snegate_twice(%arg0 : i32) -> i32 {
+  %0 = spirv.SNegate %arg0 : i32
+  %1 = spirv.SNegate %0 : i32
+
+  // CHECK: return %[[ARG]] : i32
+  return %1 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_snegate
+func.func @const_fold_scalar_snegate() -> (i32, i32, i32) {
+  %c0 = spirv.Constant 0 : i32
+  %c3 = spirv.Constant 3 : i32
+  %cn3 = spirv.Constant -3 : i32
+
+  // CHECK-DAG: %[[THREE:.*]] = spirv.Constant 3 : i32
+  // CHECK-DAG: %[[NTHREE:.*]] = spirv.Constant -3 : i32
+  // CHECK-DAG: %[[ZERO:.*]] = spirv.Constant 0 : i32
+  %0 = spirv.SNegate %c0 : i32
+  %1 = spirv.SNegate %c3 : i32
+  %2 = spirv.SNegate %cn3 : i32
+
+  // CHECK: return %[[ZERO]], %[[NTHREE]], %[[THREE]]
+  return %0, %1, %2  : i32, i32, i32
+}
+
+// CHECK-LABEL: @const_fold_vector_snegate
+func.func @const_fold_vector_snegate() -> vector<3xi32> {
+  // CHECK: spirv.Constant dense<[0, 3, -3]>
+  %cv = spirv.Constant dense<[0, -3, 3]> : vector<3xi32>
+  %0 = spirv.SNegate %cv : vector<3xi32>
+  return %0  : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.Not
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @not_twice
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @not_twice(%arg0 : i32) -> i32 {
+  %0 = spirv.Not %arg0 : i32
+  %1 = spirv.Not %0 : i32
+
+  // CHECK: return %[[ARG]] : i32
+  return %1 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_not
+func.func @const_fold_scalar_not() -> (i32, i32, i32) {
+  %c0 = spirv.Constant 0 : i32
+  %c3 = spirv.Constant 3 : i32
+  %cn3 = spirv.Constant -3 : i32
+
+  // CHECK-DAG: %[[TWO:.*]] = spirv.Constant 2 : i32
+  // CHECK-DAG: %[[NFOUR:.*]] = spirv.Constant -4 : i32
+  // CHECK-DAG: %[[NONE:.*]] = spirv.Constant -1 : i32
+  %0 = spirv.Not %c0 : i32
+  %1 = spirv.Not %c3 : i32
+  %2 = spirv.Not %cn3 : i32
+
+  // CHECK: return %[[NONE]], %[[NFOUR]], %[[TWO]]
+  return %0, %1, %2  : i32, i32, i32
+}
+
+// CHECK-LABEL: @const_fold_vector_not
+func.func @const_fold_vector_not() -> vector<3xi32> {
+  %cv = spirv.Constant dense<[-1, -4, 2]> : vector<3xi32>
+
+  // CHECK: spirv.Constant dense<[0, 3, -3]>
+  %0 = spirv.Not %cv : vector<3xi32>
+
+  return %0 : vector<3xi32>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.LogicalAnd
 //===----------------------------------------------------------------------===//
@@ -1040,6 +1124,38 @@ func.func @convert_logical_and_true_false_vector(%arg: vector<3xi1>) -> (vector<
 // spirv.LogicalNot
 //===----------------------------------------------------------------------===//
 
+// CHECK-LABEL: @logical_not_twice
+// CHECK-SAME: (%[[ARG:.*]]: i1)
+func.func @logical_not_twice(%arg0 : i1) -> i1 {
+  %0 = spirv.LogicalNot %arg0 : i1
+  %1 = spirv.LogicalNot %0 : i1
+
+  // CHECK: return %[[ARG]] : i1
+  return %1 : i1
+}
+
+// CHECK-LABEL: @const_fold_scalar_logical_not
+func.func @const_fold_scalar_logical_not() -> i1 {
+  %true = spirv.Constant true
+
+  // CHECK: spirv.Constant false
+  %0 = spirv.LogicalNot %true : i1
+
+  return %0 : i1
+}
+
+// CHECK-LABEL: @const_fold_vector_logical_not
+func.func @const_fold_vector_logical_not() -> vector<2xi1> {
+  %cv = spirv.Constant dense<[true, false]> : vector<2xi1>
+
+  // CHECK: spirv.Constant dense<[false, true]>
+  %0 = spirv.LogicalNot %cv : vector<2xi1>
+
+  return %0 : vector<2xi1>
+}
+
+// -----
+
 func.func @convert_logical_not_to_not_equal(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi1> {
   // CHECK: %[[RESULT:.*]] = spirv.INotEqual {{%.*}}, {{%.*}} : vector<3xi64>
   // CHECK-NEXT: spirv.ReturnValue %[[RESULT]] : vector<3xi1>

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, I think it would be worth adding one more test for though.

- add testcase to demonstrate SNegate behaviour for INT_MIN
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@inbelic inbelic merged commit 88151dd into llvm:main Dec 21, 2023
@inbelic inbelic deleted the inbelic/spirv-folding-negate-ops branch August 2, 2024 19:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants