Skip to content

[mlir][spirv] Add canon patterns for IAddCarry/[S|U]MulExtended #73340

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 29, 2023

Conversation

inbelic
Copy link
Contributor

@inbelic inbelic commented Nov 24, 2023

Add missing constant propogation folder for IAddCarry and [S|U]MulExtended. Due to currently missing constant value for spirv.struct the folding is done using canonicalization patterns.

Implement additional folding when rhs is 0 for all ops and when rhs is 1 for UMulExt.

This helps for readability of lowered code into SPIR-V.

Part of work for #70704

Add missing constant propogation folder for IAddCarry and
[S|U]MulExtended. Due to currently missing constant value for
spirv.struct the folding is done using canonicalization patterns.

Implement additional folding when rhs is 0 for all ops and when rhs
is 1 for UMulExt.

This helps for readability of lowered code into SPIRV.

Part of work for llvm#70704
@llvmbot
Copy link
Member

llvmbot commented Nov 24, 2023

@llvm/pr-subscribers-mlir

Author: Finn Plummer (inbelic)

Changes

Add missing constant propogation folder for IAddCarry and [S|U]MulExtended. Due to currently missing constant value for spirv.struct the folding is done using canonicalization patterns.

Implement additional folding when rhs is 0 for all ops and when rhs is 1 for UMulExt.

This helps for readability of lowered code into SPIR-V.

Part of work for #70704


Full diff: https://github.com/llvm/llvm-project/pull/73340.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td (+6)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp (+190)
  • (modified) mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir (+148)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index c4d1e01f9feef83..951cfe4feb2e63e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -379,6 +379,8 @@ def SPIRV_IAddCarryOp : SPIRV_ArithmeticExtendedBinaryOp<"IAddCarry",
     %2 = spirv.IAddCarry %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
     ```
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 // -----
@@ -607,6 +609,8 @@ def SPIRV_SMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"SMulExtended",
     %2 = spirv.SMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
     ```
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 // -----
@@ -742,6 +746,8 @@ def SPIRV_UMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"UMulExtended",
     %2 = spirv.UMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
     ```
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 // -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 9acd982dc95af6d..cefcbfd87cbdd1c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -115,6 +115,196 @@ void spirv::AccessChainOp::getCanonicalizationPatterns(
   results.add<CombineChainedAccessChain>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.IAddCarry
+//===----------------------------------------------------------------------===//
+
+// We are required to use CompositeConstructOp to create a constant struct as
+// they are not yet implemented as constant, hence we can not do so in a fold.
+struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto operands = op.getOperands();
+
+    SmallVector<Value> constituents;
+    Type constituentType = operands[0].getType();
+
+    // iaddcarry (x, 0) = <0, x>
+    if (matchPattern(operands[1], m_Zero())) {
+      constituents.push_back(operands[1]);
+      constituents.push_back(operands[0]);
+      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                               constituents);
+      return success();
+    }
+
+    // According to the SPIR-V spec:
+    //
+    //  Result Type must be from OpTypeStruct.  The struct must have two
+    //  members...
+    //
+    //  Member 0 of the result gets the low-order bits (full component width) of
+    //  the addition.
+    //
+    //  Member 1 of the result gets the high-order (carry) bit of the result of
+    //  the addition. That is, it gets the value 1 if the addition overflowed
+    //  the component width, and 0 otherwise.
+    Attribute lhs;
+    Attribute rhs;
+    if (!matchPattern(operands[0], m_Constant(&lhs)) ||
+        !matchPattern(operands[1], m_Constant(&rhs)))
+      return failure();
+
+    auto adds = constFoldBinaryOp<IntegerAttr>(
+        {lhs, rhs}, [](const APInt &a, const APInt &b) { return a + b; });
+    if (!adds)
+      return failure();
+
+    auto carrys = constFoldBinaryOp<IntegerAttr>(
+        ArrayRef{adds, lhs}, [](const APInt &a, const APInt &b) {
+          APInt zero = APInt::getZero(a.getBitWidth());
+          return a.ult(b) ? (zero + 1) : zero;
+        });
+
+    if (!carrys)
+      return failure();
+
+    Value addsVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, adds);
+    constituents.push_back(addsVal);
+
+    Value carrysVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, carrys);
+    constituents.push_back(carrysVal);
+
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                             constituents);
+    return success();
+  }
+};
+
+void spirv::IAddCarryOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<IAddCarryFold>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.[S|U]MulExtended
+//===----------------------------------------------------------------------===//
+
+// We are required to use CompositeConstructOp to create a constant struct as
+// they are not yet implemented as constant, hence we can not do so in a fold.
+template <typename MulOp, bool IsSigned>
+struct MulExtendedFold final : OpRewritePattern<MulOp> {
+  using OpRewritePattern<MulOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(MulOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto operands = op.getOperands();
+
+    SmallVector<Value> constituents;
+    Type constituentType = operands[0].getType();
+
+    // [su]mulextended (x, 0) = <0, 0>
+    if (matchPattern(operands[1], m_Zero())) {
+      Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
+      constituents.push_back(zero);
+      constituents.push_back(zero);
+      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                               constituents);
+      return success();
+    }
+
+    // According to the SPIR-V spec:
+    //
+    // Result Type must be from OpTypeStruct.  The struct must have two
+    // members...
+    //
+    // Member 0 of the result gets the low-order bits of the multiplication.
+    //
+    // Member 1 of the result gets the high-order bits of the multiplication.
+    Attribute lhs;
+    Attribute rhs;
+    if (!matchPattern(operands[0], m_Constant(&lhs)) ||
+        !matchPattern(operands[1], m_Constant(&rhs)))
+      return failure();
+
+    auto lowBits = constFoldBinaryOp<IntegerAttr>(
+        {lhs, rhs}, [](const APInt &a, const APInt &b) { return a * b; });
+
+    if (!lowBits)
+      return failure();
+
+    auto highBits = constFoldBinaryOp<IntegerAttr>(
+        {lhs, rhs}, [](const APInt &a, const APInt &b) {
+          unsigned bitWidth = a.getBitWidth();
+          APInt c;
+          if (IsSigned) {
+            c = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
+          } else {
+            c = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
+          }
+          return c.extractBits(bitWidth, bitWidth); // Extract high result
+        });
+
+    if (!highBits)
+      return failure();
+
+    Value lowBitsVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, lowBits);
+    constituents.push_back(lowBitsVal);
+
+    Value highBitsVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, highBits);
+    constituents.push_back(highBitsVal);
+
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                             constituents);
+    return success();
+  }
+};
+
+using SMulExtendedOpFold = MulExtendedFold<spirv::SMulExtendedOp, true>;
+void spirv::SMulExtendedOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<SMulExtendedOpFold>(context);
+}
+
+struct UMulExtendedOpXOne final : OpRewritePattern<spirv::UMulExtendedOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(spirv::UMulExtendedOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto operands = op.getOperands();
+
+    SmallVector<Value> constituents;
+    Type constituentType = operands[0].getType();
+
+    // umulextended (x, 1) = <x, 0>
+    if (matchPattern(operands[1], m_One())) {
+      Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
+      constituents.push_back(operands[0]);
+      constituents.push_back(zero);
+      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                               constituents);
+      return success();
+    }
+
+    return failure();
+  }
+};
+
+using UMulExtendedOpFold = MulExtendedFold<spirv::UMulExtendedOp, false>;
+void spirv::UMulExtendedOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<UMulExtendedOpFold, UMulExtendedOpXOne>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.UMod
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 0200805a444397a..16215e21b369584 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -336,6 +336,52 @@ func.func @iadd_poison(%arg0: i32) -> i32 {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.IAddCarry
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @iaddcarry_x_0
+func.func @iaddcarry_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+  %c0 = spirv.Constant 0 : i32
+
+  // CHECK: spirv.CompositeConstruct
+  %0 = spirv.IAddCarry %arg0, %c0 : !spirv.struct<(i32, i32)>
+  return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_scalar_iaddcarry
+func.func @const_fold_scalar_iaddcarry() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
+  %c5 = spirv.Constant 5 : i32
+  %cn5 = spirv.Constant -5 : i32
+  %cn8 = spirv.Constant -8 : i32
+
+  // CHECK-DAG: spirv.Constant 0
+  // CHECK-DAG: spirv.Constant -3
+  // CHECK-DAG: spirv.CompositeConstruct
+  // CHECK-DAG: spirv.Constant 1
+  // CHECK-DAG: spirv.Constant -13
+  // CHECK-DAG: spirv.CompositeConstruct
+  %0 = spirv.IAddCarry %c5, %cn8 : !spirv.struct<(i32, i32)>
+  %1 = spirv.IAddCarry %cn5, %cn8 : !spirv.struct<(i32, i32)>
+
+  return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_vector_iaddcarry
+func.func @const_fold_vector_iaddcarry() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
+  %v0 = spirv.Constant dense<[5, -3, -1]> : vector<3xi32>
+  %v1 = spirv.Constant dense<[-8, -8, 1]> : vector<3xi32>
+
+  // CHECK-DAG: spirv.Constant dense<[0, 1, 1]>
+  // CHECK-DAG: spirv.Constant dense<[-3, -11, 0]>
+  // CHECK-DAG: spirv.CompositeConstruct
+  %0 = spirv.IAddCarry %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+  return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.IMul
 //===----------------------------------------------------------------------===//
@@ -400,6 +446,108 @@ func.func @const_fold_vector_imul() -> vector<3xi32> {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.SMulExtended
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @smulextended_x_0
+func.func @smulextended_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+  %c0 = spirv.Constant 0 : i32
+
+  // CHECK: spirv.CompositeConstruct
+  %0 = spirv.SMulExtended %arg0, %c0 : !spirv.struct<(i32, i32)>
+  return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_scalar_smulextended
+func.func @const_fold_scalar_smulextended() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
+  %c5 = spirv.Constant 5 : i32
+  %cn5 = spirv.Constant -5 : i32
+  %cn8 = spirv.Constant -8 : i32
+
+  // CHECK-DAG: spirv.Constant -40
+  // CHECK-DAG: spirv.Constant -1
+  // CHECK-DAG: spirv.CompositeConstruct
+  // CHECK-DAG: spirv.Constant 40
+  // CHECK-DAG: spirv.Constant 0
+  // CHECK-DAG: spirv.CompositeConstruct
+  %0 = spirv.SMulExtended %c5, %cn8 : !spirv.struct<(i32, i32)>
+  %1 = spirv.SMulExtended %cn5, %cn8 : !spirv.struct<(i32, i32)>
+
+  return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_vector_smulextended
+func.func @const_fold_vector_smulextended() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
+  %v0 = spirv.Constant dense<[2147483647, -5, -1]> : vector<3xi32>
+  %v1 = spirv.Constant dense<[5, -8, 1]> : vector<3xi32>
+
+  // CHECK: spirv.Constant dense<[2147483643, 40, -1]>
+  // CHECK-NEXT: spirv.Constant dense<[2, 0, -1]>
+  // CHECK-NEXT: spirv.CompositeConstruct
+  %0 = spirv.SMulExtended %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+  return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.UMulExtended
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @umulextended_x_0
+func.func @umulextended_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+  %c0 = spirv.Constant 0 : i32
+
+  // CHECK: spirv.CompositeConstruct
+  %0 = spirv.UMulExtended %arg0, %c0 : !spirv.struct<(i32, i32)>
+  return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @umulextended_x_1
+func.func @umulextended_x_1(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+  %c0 = spirv.Constant 1 : i32
+
+  // CHECK: spirv.CompositeConstruct
+  %0 = spirv.UMulExtended %arg0, %c0 : !spirv.struct<(i32, i32)>
+  return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_scalar_umulextended
+func.func @const_fold_scalar_umulextended() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
+  %c5 = spirv.Constant 5 : i32
+  %cn5 = spirv.Constant -5 : i32
+  %cn8 = spirv.Constant -8 : i32
+
+  // CHECK-DAG: spirv.Constant 40
+  // CHECK-DAG: spirv.Constant -13
+  // CHECK-DAG: spirv.CompositeConstruct
+  // CHECK-DAG: spirv.Constant -40
+  // CHECK-DAG: spirv.Constant 4
+  // CHECK-DAG: spirv.CompositeConstruct
+  %0 = spirv.UMulExtended %c5, %cn8 : !spirv.struct<(i32, i32)>
+  %1 = spirv.UMulExtended %cn5, %cn8 : !spirv.struct<(i32, i32)>
+
+  return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_vector_umulextended
+func.func @const_fold_vector_umulextended() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
+  %v0 = spirv.Constant dense<[2147483647, -5, -1]> : vector<3xi32>
+  %v1 = spirv.Constant dense<[5, -8, 1]> : vector<3xi32>
+
+  // CHECK: spirv.Constant dense<[2147483643, 40, -1]>
+  // CHECK-NEXT: spirv.Constant dense<[2, -13, 0]>
+  // CHECK-NEXT: spirv.CompositeConstruct
+  %0 = spirv.UMulExtended %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+  return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+
+}
+
+// -----
+
+
 //===----------------------------------------------------------------------===//
 // spirv.ISub
 //===----------------------------------------------------------------------===//

@llvmbot
Copy link
Member

llvmbot commented Nov 24, 2023

@llvm/pr-subscribers-mlir-spirv

Author: Finn Plummer (inbelic)

Changes

Add missing constant propogation folder for IAddCarry and [S|U]MulExtended. Due to currently missing constant value for spirv.struct the folding is done using canonicalization patterns.

Implement additional folding when rhs is 0 for all ops and when rhs is 1 for UMulExt.

This helps for readability of lowered code into SPIR-V.

Part of work for #70704


Full diff: https://github.com/llvm/llvm-project/pull/73340.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td (+6)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp (+190)
  • (modified) mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir (+148)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index c4d1e01f9feef83..951cfe4feb2e63e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -379,6 +379,8 @@ def SPIRV_IAddCarryOp : SPIRV_ArithmeticExtendedBinaryOp<"IAddCarry",
     %2 = spirv.IAddCarry %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
     ```
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 // -----
@@ -607,6 +609,8 @@ def SPIRV_SMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"SMulExtended",
     %2 = spirv.SMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
     ```
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 // -----
@@ -742,6 +746,8 @@ def SPIRV_UMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"UMulExtended",
     %2 = spirv.UMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
     ```
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 // -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 9acd982dc95af6d..cefcbfd87cbdd1c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -115,6 +115,196 @@ void spirv::AccessChainOp::getCanonicalizationPatterns(
   results.add<CombineChainedAccessChain>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.IAddCarry
+//===----------------------------------------------------------------------===//
+
+// We are required to use CompositeConstructOp to create a constant struct as
+// they are not yet implemented as constant, hence we can not do so in a fold.
+struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto operands = op.getOperands();
+
+    SmallVector<Value> constituents;
+    Type constituentType = operands[0].getType();
+
+    // iaddcarry (x, 0) = <0, x>
+    if (matchPattern(operands[1], m_Zero())) {
+      constituents.push_back(operands[1]);
+      constituents.push_back(operands[0]);
+      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                               constituents);
+      return success();
+    }
+
+    // According to the SPIR-V spec:
+    //
+    //  Result Type must be from OpTypeStruct.  The struct must have two
+    //  members...
+    //
+    //  Member 0 of the result gets the low-order bits (full component width) of
+    //  the addition.
+    //
+    //  Member 1 of the result gets the high-order (carry) bit of the result of
+    //  the addition. That is, it gets the value 1 if the addition overflowed
+    //  the component width, and 0 otherwise.
+    Attribute lhs;
+    Attribute rhs;
+    if (!matchPattern(operands[0], m_Constant(&lhs)) ||
+        !matchPattern(operands[1], m_Constant(&rhs)))
+      return failure();
+
+    auto adds = constFoldBinaryOp<IntegerAttr>(
+        {lhs, rhs}, [](const APInt &a, const APInt &b) { return a + b; });
+    if (!adds)
+      return failure();
+
+    auto carrys = constFoldBinaryOp<IntegerAttr>(
+        ArrayRef{adds, lhs}, [](const APInt &a, const APInt &b) {
+          APInt zero = APInt::getZero(a.getBitWidth());
+          return a.ult(b) ? (zero + 1) : zero;
+        });
+
+    if (!carrys)
+      return failure();
+
+    Value addsVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, adds);
+    constituents.push_back(addsVal);
+
+    Value carrysVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, carrys);
+    constituents.push_back(carrysVal);
+
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                             constituents);
+    return success();
+  }
+};
+
+void spirv::IAddCarryOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<IAddCarryFold>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.[S|U]MulExtended
+//===----------------------------------------------------------------------===//
+
+// We are required to use CompositeConstructOp to create a constant struct as
+// they are not yet implemented as constant, hence we can not do so in a fold.
+template <typename MulOp, bool IsSigned>
+struct MulExtendedFold final : OpRewritePattern<MulOp> {
+  using OpRewritePattern<MulOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(MulOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto operands = op.getOperands();
+
+    SmallVector<Value> constituents;
+    Type constituentType = operands[0].getType();
+
+    // [su]mulextended (x, 0) = <0, 0>
+    if (matchPattern(operands[1], m_Zero())) {
+      Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
+      constituents.push_back(zero);
+      constituents.push_back(zero);
+      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                               constituents);
+      return success();
+    }
+
+    // According to the SPIR-V spec:
+    //
+    // Result Type must be from OpTypeStruct.  The struct must have two
+    // members...
+    //
+    // Member 0 of the result gets the low-order bits of the multiplication.
+    //
+    // Member 1 of the result gets the high-order bits of the multiplication.
+    Attribute lhs;
+    Attribute rhs;
+    if (!matchPattern(operands[0], m_Constant(&lhs)) ||
+        !matchPattern(operands[1], m_Constant(&rhs)))
+      return failure();
+
+    auto lowBits = constFoldBinaryOp<IntegerAttr>(
+        {lhs, rhs}, [](const APInt &a, const APInt &b) { return a * b; });
+
+    if (!lowBits)
+      return failure();
+
+    auto highBits = constFoldBinaryOp<IntegerAttr>(
+        {lhs, rhs}, [](const APInt &a, const APInt &b) {
+          unsigned bitWidth = a.getBitWidth();
+          APInt c;
+          if (IsSigned) {
+            c = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
+          } else {
+            c = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
+          }
+          return c.extractBits(bitWidth, bitWidth); // Extract high result
+        });
+
+    if (!highBits)
+      return failure();
+
+    Value lowBitsVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, lowBits);
+    constituents.push_back(lowBitsVal);
+
+    Value highBitsVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, highBits);
+    constituents.push_back(highBitsVal);
+
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                             constituents);
+    return success();
+  }
+};
+
+using SMulExtendedOpFold = MulExtendedFold<spirv::SMulExtendedOp, true>;
+void spirv::SMulExtendedOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<SMulExtendedOpFold>(context);
+}
+
+struct UMulExtendedOpXOne final : OpRewritePattern<spirv::UMulExtendedOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(spirv::UMulExtendedOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto operands = op.getOperands();
+
+    SmallVector<Value> constituents;
+    Type constituentType = operands[0].getType();
+
+    // umulextended (x, 1) = <x, 0>
+    if (matchPattern(operands[1], m_One())) {
+      Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
+      constituents.push_back(operands[0]);
+      constituents.push_back(zero);
+      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                               constituents);
+      return success();
+    }
+
+    return failure();
+  }
+};
+
+using UMulExtendedOpFold = MulExtendedFold<spirv::UMulExtendedOp, false>;
+void spirv::UMulExtendedOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<UMulExtendedOpFold, UMulExtendedOpXOne>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.UMod
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 0200805a444397a..16215e21b369584 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -336,6 +336,52 @@ func.func @iadd_poison(%arg0: i32) -> i32 {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.IAddCarry
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @iaddcarry_x_0
+func.func @iaddcarry_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+  %c0 = spirv.Constant 0 : i32
+
+  // CHECK: spirv.CompositeConstruct
+  %0 = spirv.IAddCarry %arg0, %c0 : !spirv.struct<(i32, i32)>
+  return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_scalar_iaddcarry
+func.func @const_fold_scalar_iaddcarry() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
+  %c5 = spirv.Constant 5 : i32
+  %cn5 = spirv.Constant -5 : i32
+  %cn8 = spirv.Constant -8 : i32
+
+  // CHECK-DAG: spirv.Constant 0
+  // CHECK-DAG: spirv.Constant -3
+  // CHECK-DAG: spirv.CompositeConstruct
+  // CHECK-DAG: spirv.Constant 1
+  // CHECK-DAG: spirv.Constant -13
+  // CHECK-DAG: spirv.CompositeConstruct
+  %0 = spirv.IAddCarry %c5, %cn8 : !spirv.struct<(i32, i32)>
+  %1 = spirv.IAddCarry %cn5, %cn8 : !spirv.struct<(i32, i32)>
+
+  return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_vector_iaddcarry
+func.func @const_fold_vector_iaddcarry() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
+  %v0 = spirv.Constant dense<[5, -3, -1]> : vector<3xi32>
+  %v1 = spirv.Constant dense<[-8, -8, 1]> : vector<3xi32>
+
+  // CHECK-DAG: spirv.Constant dense<[0, 1, 1]>
+  // CHECK-DAG: spirv.Constant dense<[-3, -11, 0]>
+  // CHECK-DAG: spirv.CompositeConstruct
+  %0 = spirv.IAddCarry %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+  return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.IMul
 //===----------------------------------------------------------------------===//
@@ -400,6 +446,108 @@ func.func @const_fold_vector_imul() -> vector<3xi32> {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.SMulExtended
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @smulextended_x_0
+func.func @smulextended_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+  %c0 = spirv.Constant 0 : i32
+
+  // CHECK: spirv.CompositeConstruct
+  %0 = spirv.SMulExtended %arg0, %c0 : !spirv.struct<(i32, i32)>
+  return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_scalar_smulextended
+func.func @const_fold_scalar_smulextended() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
+  %c5 = spirv.Constant 5 : i32
+  %cn5 = spirv.Constant -5 : i32
+  %cn8 = spirv.Constant -8 : i32
+
+  // CHECK-DAG: spirv.Constant -40
+  // CHECK-DAG: spirv.Constant -1
+  // CHECK-DAG: spirv.CompositeConstruct
+  // CHECK-DAG: spirv.Constant 40
+  // CHECK-DAG: spirv.Constant 0
+  // CHECK-DAG: spirv.CompositeConstruct
+  %0 = spirv.SMulExtended %c5, %cn8 : !spirv.struct<(i32, i32)>
+  %1 = spirv.SMulExtended %cn5, %cn8 : !spirv.struct<(i32, i32)>
+
+  return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_vector_smulextended
+func.func @const_fold_vector_smulextended() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
+  %v0 = spirv.Constant dense<[2147483647, -5, -1]> : vector<3xi32>
+  %v1 = spirv.Constant dense<[5, -8, 1]> : vector<3xi32>
+
+  // CHECK: spirv.Constant dense<[2147483643, 40, -1]>
+  // CHECK-NEXT: spirv.Constant dense<[2, 0, -1]>
+  // CHECK-NEXT: spirv.CompositeConstruct
+  %0 = spirv.SMulExtended %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+  return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.UMulExtended
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @umulextended_x_0
+func.func @umulextended_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+  %c0 = spirv.Constant 0 : i32
+
+  // CHECK: spirv.CompositeConstruct
+  %0 = spirv.UMulExtended %arg0, %c0 : !spirv.struct<(i32, i32)>
+  return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @umulextended_x_1
+func.func @umulextended_x_1(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+  %c0 = spirv.Constant 1 : i32
+
+  // CHECK: spirv.CompositeConstruct
+  %0 = spirv.UMulExtended %arg0, %c0 : !spirv.struct<(i32, i32)>
+  return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_scalar_umulextended
+func.func @const_fold_scalar_umulextended() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
+  %c5 = spirv.Constant 5 : i32
+  %cn5 = spirv.Constant -5 : i32
+  %cn8 = spirv.Constant -8 : i32
+
+  // CHECK-DAG: spirv.Constant 40
+  // CHECK-DAG: spirv.Constant -13
+  // CHECK-DAG: spirv.CompositeConstruct
+  // CHECK-DAG: spirv.Constant -40
+  // CHECK-DAG: spirv.Constant 4
+  // CHECK-DAG: spirv.CompositeConstruct
+  %0 = spirv.UMulExtended %c5, %cn8 : !spirv.struct<(i32, i32)>
+  %1 = spirv.UMulExtended %cn5, %cn8 : !spirv.struct<(i32, i32)>
+
+  return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_vector_umulextended
+func.func @const_fold_vector_umulextended() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
+  %v0 = spirv.Constant dense<[2147483647, -5, -1]> : vector<3xi32>
+  %v1 = spirv.Constant dense<[5, -8, 1]> : vector<3xi32>
+
+  // CHECK: spirv.Constant dense<[2147483643, 40, -1]>
+  // CHECK-NEXT: spirv.Constant dense<[2, -13, 0]>
+  // CHECK-NEXT: spirv.CompositeConstruct
+  %0 = spirv.UMulExtended %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+  return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+
+}
+
+// -----
+
+
 //===----------------------------------------------------------------------===//
 // spirv.ISub
 //===----------------------------------------------------------------------===//

LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto operands = op.getOperands();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can do:

Value lhs = op.getOperand1();
Value rhs = op.getOperand2();

If this is not a very intuitive naming, I can add extra getters op.getLhs() and op.getRhs() -- WDYT?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree this looks better. I think it is okay to be done locally rather than having the getters.

Comment on lines 163 to 173
if (!adds)
return failure();

auto carrys = constFoldBinaryOp<IntegerAttr>(
ArrayRef{adds, lhs}, [](const APInt &a, const APInt &b) {
APInt zero = APInt::getZero(a.getBitWidth());
return a.ult(b) ? (zero + 1) : zero;
});

if (!carrys)
return failure();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that instead of using CompositeConstruct we could also create two insertions into the struct -- one for the sum, a separate one for the carry. This way we don't have to be able to fold both to simplify the result.

For example when we see something like this:

%x = spirv.UConvert %a : i8 to i32
%y = spirv.UConvert %b : i8 to i32
%res = spirv.IAddCarry %a, %b : i32

Here we can statically tell that the overflow bit is 0, but can't necessarily decide that the sum component is.

This may be a bit obscure, but maybe we could add it as a TODO comment.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that if it can allow for more greedy folding then we should use it. We also have the -spirv-rewrite-inserts pass that can convert from a CompositeInsert chain to the CompositeConstruct in the cases when desired.

The most recent commit has a potential implementation of switching to this.

- improve readability with lhs/rhs instead of operands[0]/[1]
- use stack array instead of llvm::SmallVector
- increase strictness of tests to ensure proper CompositeConstruct and
  return order
Copy link

github-actions bot commented Nov 29, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

- commit to demonstrate how we could potentially use CompositeInsert
  instead of CompositeConstruct
@inbelic inbelic force-pushed the inbelic/spirv-folding-ext-ops branch from 6a51740 to a0cb359 Compare November 29, 2023 17:17
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@kuhar kuhar changed the title [mlir][spirv] Add folding for IAddCarry/[S|U]MulExtended [mlir][spirv] Add canon patterns for IAddCarry/[S|U]MulExtended Nov 29, 2023
@kuhar kuhar merged commit 14028ec into llvm:main Nov 29, 2023
@kuhar
Copy link
Member

kuhar commented Nov 29, 2023

BTW @inbelic, with this many PRs in you should be able to request write access if you haven't done so already. The instructions are here: https://llvm.org/docs/DeveloperPolicy.html#obtaining-commit-access

@inbelic inbelic deleted the inbelic/spirv-folding-ext-ops branch March 21, 2024 15:50
@inbelic inbelic restored the inbelic/spirv-folding-ext-ops branch March 21, 2024 15:50
@inbelic inbelic deleted the inbelic/spirv-folding-ext-ops branch March 21, 2024 15:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants