-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][spirv] Add canon patterns for IAddCarry/[S|U]MulExtended #73340
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Add missing constant propogation folder for IAddCarry and [S|U]MulExtended. Due to currently missing constant value for spirv.struct the folding is done using canonicalization patterns. Implement additional folding when rhs is 0 for all ops and when rhs is 1 for UMulExt. This helps for readability of lowered code into SPIRV. Part of work for llvm#70704
@llvm/pr-subscribers-mlir Author: Finn Plummer (inbelic) ChangesAdd missing constant propogation folder for IAddCarry and [S|U]MulExtended. Due to currently missing constant value for spirv.struct the folding is done using canonicalization patterns. Implement additional folding when rhs is 0 for all ops and when rhs is 1 for UMulExt. This helps for readability of lowered code into SPIR-V. Part of work for #70704 Full diff: https://github.com/llvm/llvm-project/pull/73340.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index c4d1e01f9feef83..951cfe4feb2e63e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -379,6 +379,8 @@ def SPIRV_IAddCarryOp : SPIRV_ArithmeticExtendedBinaryOp<"IAddCarry",
%2 = spirv.IAddCarry %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
```
}];
+
+ let hasCanonicalizer = 1;
}
// -----
@@ -607,6 +609,8 @@ def SPIRV_SMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"SMulExtended",
%2 = spirv.SMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
```
}];
+
+ let hasCanonicalizer = 1;
}
// -----
@@ -742,6 +746,8 @@ def SPIRV_UMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"UMulExtended",
%2 = spirv.UMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
```
}];
+
+ let hasCanonicalizer = 1;
}
// -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 9acd982dc95af6d..cefcbfd87cbdd1c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -115,6 +115,196 @@ void spirv::AccessChainOp::getCanonicalizationPatterns(
results.add<CombineChainedAccessChain>(context);
}
+//===----------------------------------------------------------------------===//
+// spirv.IAddCarry
+//===----------------------------------------------------------------------===//
+
+// We are required to use CompositeConstructOp to create a constant struct as
+// they are not yet implemented as constant, hence we can not do so in a fold.
+struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ auto operands = op.getOperands();
+
+ SmallVector<Value> constituents;
+ Type constituentType = operands[0].getType();
+
+ // iaddcarry (x, 0) = <0, x>
+ if (matchPattern(operands[1], m_Zero())) {
+ constituents.push_back(operands[1]);
+ constituents.push_back(operands[0]);
+ rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+ constituents);
+ return success();
+ }
+
+ // According to the SPIR-V spec:
+ //
+ // Result Type must be from OpTypeStruct. The struct must have two
+ // members...
+ //
+ // Member 0 of the result gets the low-order bits (full component width) of
+ // the addition.
+ //
+ // Member 1 of the result gets the high-order (carry) bit of the result of
+ // the addition. That is, it gets the value 1 if the addition overflowed
+ // the component width, and 0 otherwise.
+ Attribute lhs;
+ Attribute rhs;
+ if (!matchPattern(operands[0], m_Constant(&lhs)) ||
+ !matchPattern(operands[1], m_Constant(&rhs)))
+ return failure();
+
+ auto adds = constFoldBinaryOp<IntegerAttr>(
+ {lhs, rhs}, [](const APInt &a, const APInt &b) { return a + b; });
+ if (!adds)
+ return failure();
+
+ auto carrys = constFoldBinaryOp<IntegerAttr>(
+ ArrayRef{adds, lhs}, [](const APInt &a, const APInt &b) {
+ APInt zero = APInt::getZero(a.getBitWidth());
+ return a.ult(b) ? (zero + 1) : zero;
+ });
+
+ if (!carrys)
+ return failure();
+
+ Value addsVal =
+ rewriter.create<spirv::ConstantOp>(loc, constituentType, adds);
+ constituents.push_back(addsVal);
+
+ Value carrysVal =
+ rewriter.create<spirv::ConstantOp>(loc, constituentType, carrys);
+ constituents.push_back(carrysVal);
+
+ rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+ constituents);
+ return success();
+ }
+};
+
+void spirv::IAddCarryOp::getCanonicalizationPatterns(
+ RewritePatternSet &patterns, MLIRContext *context) {
+ patterns.add<IAddCarryFold>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.[S|U]MulExtended
+//===----------------------------------------------------------------------===//
+
+// We are required to use CompositeConstructOp to create a constant struct as
+// they are not yet implemented as constant, hence we can not do so in a fold.
+template <typename MulOp, bool IsSigned>
+struct MulExtendedFold final : OpRewritePattern<MulOp> {
+ using OpRewritePattern<MulOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(MulOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ auto operands = op.getOperands();
+
+ SmallVector<Value> constituents;
+ Type constituentType = operands[0].getType();
+
+ // [su]mulextended (x, 0) = <0, 0>
+ if (matchPattern(operands[1], m_Zero())) {
+ Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
+ constituents.push_back(zero);
+ constituents.push_back(zero);
+ rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+ constituents);
+ return success();
+ }
+
+ // According to the SPIR-V spec:
+ //
+ // Result Type must be from OpTypeStruct. The struct must have two
+ // members...
+ //
+ // Member 0 of the result gets the low-order bits of the multiplication.
+ //
+ // Member 1 of the result gets the high-order bits of the multiplication.
+ Attribute lhs;
+ Attribute rhs;
+ if (!matchPattern(operands[0], m_Constant(&lhs)) ||
+ !matchPattern(operands[1], m_Constant(&rhs)))
+ return failure();
+
+ auto lowBits = constFoldBinaryOp<IntegerAttr>(
+ {lhs, rhs}, [](const APInt &a, const APInt &b) { return a * b; });
+
+ if (!lowBits)
+ return failure();
+
+ auto highBits = constFoldBinaryOp<IntegerAttr>(
+ {lhs, rhs}, [](const APInt &a, const APInt &b) {
+ unsigned bitWidth = a.getBitWidth();
+ APInt c;
+ if (IsSigned) {
+ c = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
+ } else {
+ c = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
+ }
+ return c.extractBits(bitWidth, bitWidth); // Extract high result
+ });
+
+ if (!highBits)
+ return failure();
+
+ Value lowBitsVal =
+ rewriter.create<spirv::ConstantOp>(loc, constituentType, lowBits);
+ constituents.push_back(lowBitsVal);
+
+ Value highBitsVal =
+ rewriter.create<spirv::ConstantOp>(loc, constituentType, highBits);
+ constituents.push_back(highBitsVal);
+
+ rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+ constituents);
+ return success();
+ }
+};
+
+using SMulExtendedOpFold = MulExtendedFold<spirv::SMulExtendedOp, true>;
+void spirv::SMulExtendedOp::getCanonicalizationPatterns(
+ RewritePatternSet &patterns, MLIRContext *context) {
+ patterns.add<SMulExtendedOpFold>(context);
+}
+
+struct UMulExtendedOpXOne final : OpRewritePattern<spirv::UMulExtendedOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(spirv::UMulExtendedOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ auto operands = op.getOperands();
+
+ SmallVector<Value> constituents;
+ Type constituentType = operands[0].getType();
+
+ // umulextended (x, 1) = <x, 0>
+ if (matchPattern(operands[1], m_One())) {
+ Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
+ constituents.push_back(operands[0]);
+ constituents.push_back(zero);
+ rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+ constituents);
+ return success();
+ }
+
+ return failure();
+ }
+};
+
+using UMulExtendedOpFold = MulExtendedFold<spirv::UMulExtendedOp, false>;
+void spirv::UMulExtendedOp::getCanonicalizationPatterns(
+ RewritePatternSet &patterns, MLIRContext *context) {
+ patterns.add<UMulExtendedOpFold, UMulExtendedOpXOne>(context);
+}
+
//===----------------------------------------------------------------------===//
// spirv.UMod
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 0200805a444397a..16215e21b369584 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -336,6 +336,52 @@ func.func @iadd_poison(%arg0: i32) -> i32 {
// -----
+//===----------------------------------------------------------------------===//
+// spirv.IAddCarry
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @iaddcarry_x_0
+func.func @iaddcarry_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+ %c0 = spirv.Constant 0 : i32
+
+ // CHECK: spirv.CompositeConstruct
+ %0 = spirv.IAddCarry %arg0, %c0 : !spirv.struct<(i32, i32)>
+ return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_scalar_iaddcarry
+func.func @const_fold_scalar_iaddcarry() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
+ %c5 = spirv.Constant 5 : i32
+ %cn5 = spirv.Constant -5 : i32
+ %cn8 = spirv.Constant -8 : i32
+
+ // CHECK-DAG: spirv.Constant 0
+ // CHECK-DAG: spirv.Constant -3
+ // CHECK-DAG: spirv.CompositeConstruct
+ // CHECK-DAG: spirv.Constant 1
+ // CHECK-DAG: spirv.Constant -13
+ // CHECK-DAG: spirv.CompositeConstruct
+ %0 = spirv.IAddCarry %c5, %cn8 : !spirv.struct<(i32, i32)>
+ %1 = spirv.IAddCarry %cn5, %cn8 : !spirv.struct<(i32, i32)>
+
+ return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_vector_iaddcarry
+func.func @const_fold_vector_iaddcarry() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
+ %v0 = spirv.Constant dense<[5, -3, -1]> : vector<3xi32>
+ %v1 = spirv.Constant dense<[-8, -8, 1]> : vector<3xi32>
+
+ // CHECK-DAG: spirv.Constant dense<[0, 1, 1]>
+ // CHECK-DAG: spirv.Constant dense<[-3, -11, 0]>
+ // CHECK-DAG: spirv.CompositeConstruct
+ %0 = spirv.IAddCarry %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+ return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.IMul
//===----------------------------------------------------------------------===//
@@ -400,6 +446,108 @@ func.func @const_fold_vector_imul() -> vector<3xi32> {
// -----
+//===----------------------------------------------------------------------===//
+// spirv.SMulExtended
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @smulextended_x_0
+func.func @smulextended_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+ %c0 = spirv.Constant 0 : i32
+
+ // CHECK: spirv.CompositeConstruct
+ %0 = spirv.SMulExtended %arg0, %c0 : !spirv.struct<(i32, i32)>
+ return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_scalar_smulextended
+func.func @const_fold_scalar_smulextended() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
+ %c5 = spirv.Constant 5 : i32
+ %cn5 = spirv.Constant -5 : i32
+ %cn8 = spirv.Constant -8 : i32
+
+ // CHECK-DAG: spirv.Constant -40
+ // CHECK-DAG: spirv.Constant -1
+ // CHECK-DAG: spirv.CompositeConstruct
+ // CHECK-DAG: spirv.Constant 40
+ // CHECK-DAG: spirv.Constant 0
+ // CHECK-DAG: spirv.CompositeConstruct
+ %0 = spirv.SMulExtended %c5, %cn8 : !spirv.struct<(i32, i32)>
+ %1 = spirv.SMulExtended %cn5, %cn8 : !spirv.struct<(i32, i32)>
+
+ return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_vector_smulextended
+func.func @const_fold_vector_smulextended() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
+ %v0 = spirv.Constant dense<[2147483647, -5, -1]> : vector<3xi32>
+ %v1 = spirv.Constant dense<[5, -8, 1]> : vector<3xi32>
+
+ // CHECK: spirv.Constant dense<[2147483643, 40, -1]>
+ // CHECK-NEXT: spirv.Constant dense<[2, 0, -1]>
+ // CHECK-NEXT: spirv.CompositeConstruct
+ %0 = spirv.SMulExtended %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+ return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.UMulExtended
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @umulextended_x_0
+func.func @umulextended_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+ %c0 = spirv.Constant 0 : i32
+
+ // CHECK: spirv.CompositeConstruct
+ %0 = spirv.UMulExtended %arg0, %c0 : !spirv.struct<(i32, i32)>
+ return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @umulextended_x_1
+func.func @umulextended_x_1(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+ %c0 = spirv.Constant 1 : i32
+
+ // CHECK: spirv.CompositeConstruct
+ %0 = spirv.UMulExtended %arg0, %c0 : !spirv.struct<(i32, i32)>
+ return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_scalar_umulextended
+func.func @const_fold_scalar_umulextended() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
+ %c5 = spirv.Constant 5 : i32
+ %cn5 = spirv.Constant -5 : i32
+ %cn8 = spirv.Constant -8 : i32
+
+ // CHECK-DAG: spirv.Constant 40
+ // CHECK-DAG: spirv.Constant -13
+ // CHECK-DAG: spirv.CompositeConstruct
+ // CHECK-DAG: spirv.Constant -40
+ // CHECK-DAG: spirv.Constant 4
+ // CHECK-DAG: spirv.CompositeConstruct
+ %0 = spirv.UMulExtended %c5, %cn8 : !spirv.struct<(i32, i32)>
+ %1 = spirv.UMulExtended %cn5, %cn8 : !spirv.struct<(i32, i32)>
+
+ return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_vector_umulextended
+func.func @const_fold_vector_umulextended() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
+ %v0 = spirv.Constant dense<[2147483647, -5, -1]> : vector<3xi32>
+ %v1 = spirv.Constant dense<[5, -8, 1]> : vector<3xi32>
+
+ // CHECK: spirv.Constant dense<[2147483643, 40, -1]>
+ // CHECK-NEXT: spirv.Constant dense<[2, -13, 0]>
+ // CHECK-NEXT: spirv.CompositeConstruct
+ %0 = spirv.UMulExtended %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+ return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+
+}
+
+// -----
+
+
//===----------------------------------------------------------------------===//
// spirv.ISub
//===----------------------------------------------------------------------===//
|
@llvm/pr-subscribers-mlir-spirv Author: Finn Plummer (inbelic) ChangesAdd missing constant propogation folder for IAddCarry and [S|U]MulExtended. Due to currently missing constant value for spirv.struct the folding is done using canonicalization patterns. Implement additional folding when rhs is 0 for all ops and when rhs is 1 for UMulExt. This helps for readability of lowered code into SPIR-V. Part of work for #70704 Full diff: https://github.com/llvm/llvm-project/pull/73340.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index c4d1e01f9feef83..951cfe4feb2e63e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -379,6 +379,8 @@ def SPIRV_IAddCarryOp : SPIRV_ArithmeticExtendedBinaryOp<"IAddCarry",
%2 = spirv.IAddCarry %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
```
}];
+
+ let hasCanonicalizer = 1;
}
// -----
@@ -607,6 +609,8 @@ def SPIRV_SMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"SMulExtended",
%2 = spirv.SMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
```
}];
+
+ let hasCanonicalizer = 1;
}
// -----
@@ -742,6 +746,8 @@ def SPIRV_UMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"UMulExtended",
%2 = spirv.UMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
```
}];
+
+ let hasCanonicalizer = 1;
}
// -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 9acd982dc95af6d..cefcbfd87cbdd1c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -115,6 +115,196 @@ void spirv::AccessChainOp::getCanonicalizationPatterns(
results.add<CombineChainedAccessChain>(context);
}
+//===----------------------------------------------------------------------===//
+// spirv.IAddCarry
+//===----------------------------------------------------------------------===//
+
+// We are required to use CompositeConstructOp to create a constant struct as
+// they are not yet implemented as constant, hence we can not do so in a fold.
+struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ auto operands = op.getOperands();
+
+ SmallVector<Value> constituents;
+ Type constituentType = operands[0].getType();
+
+ // iaddcarry (x, 0) = <0, x>
+ if (matchPattern(operands[1], m_Zero())) {
+ constituents.push_back(operands[1]);
+ constituents.push_back(operands[0]);
+ rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+ constituents);
+ return success();
+ }
+
+ // According to the SPIR-V spec:
+ //
+ // Result Type must be from OpTypeStruct. The struct must have two
+ // members...
+ //
+ // Member 0 of the result gets the low-order bits (full component width) of
+ // the addition.
+ //
+ // Member 1 of the result gets the high-order (carry) bit of the result of
+ // the addition. That is, it gets the value 1 if the addition overflowed
+ // the component width, and 0 otherwise.
+ Attribute lhs;
+ Attribute rhs;
+ if (!matchPattern(operands[0], m_Constant(&lhs)) ||
+ !matchPattern(operands[1], m_Constant(&rhs)))
+ return failure();
+
+ auto adds = constFoldBinaryOp<IntegerAttr>(
+ {lhs, rhs}, [](const APInt &a, const APInt &b) { return a + b; });
+ if (!adds)
+ return failure();
+
+ auto carrys = constFoldBinaryOp<IntegerAttr>(
+ ArrayRef{adds, lhs}, [](const APInt &a, const APInt &b) {
+ APInt zero = APInt::getZero(a.getBitWidth());
+ return a.ult(b) ? (zero + 1) : zero;
+ });
+
+ if (!carrys)
+ return failure();
+
+ Value addsVal =
+ rewriter.create<spirv::ConstantOp>(loc, constituentType, adds);
+ constituents.push_back(addsVal);
+
+ Value carrysVal =
+ rewriter.create<spirv::ConstantOp>(loc, constituentType, carrys);
+ constituents.push_back(carrysVal);
+
+ rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+ constituents);
+ return success();
+ }
+};
+
+void spirv::IAddCarryOp::getCanonicalizationPatterns(
+ RewritePatternSet &patterns, MLIRContext *context) {
+ patterns.add<IAddCarryFold>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.[S|U]MulExtended
+//===----------------------------------------------------------------------===//
+
+// We are required to use CompositeConstructOp to create a constant struct as
+// they are not yet implemented as constant, hence we can not do so in a fold.
+template <typename MulOp, bool IsSigned>
+struct MulExtendedFold final : OpRewritePattern<MulOp> {
+ using OpRewritePattern<MulOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(MulOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ auto operands = op.getOperands();
+
+ SmallVector<Value> constituents;
+ Type constituentType = operands[0].getType();
+
+ // [su]mulextended (x, 0) = <0, 0>
+ if (matchPattern(operands[1], m_Zero())) {
+ Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
+ constituents.push_back(zero);
+ constituents.push_back(zero);
+ rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+ constituents);
+ return success();
+ }
+
+ // According to the SPIR-V spec:
+ //
+ // Result Type must be from OpTypeStruct. The struct must have two
+ // members...
+ //
+ // Member 0 of the result gets the low-order bits of the multiplication.
+ //
+ // Member 1 of the result gets the high-order bits of the multiplication.
+ Attribute lhs;
+ Attribute rhs;
+ if (!matchPattern(operands[0], m_Constant(&lhs)) ||
+ !matchPattern(operands[1], m_Constant(&rhs)))
+ return failure();
+
+ auto lowBits = constFoldBinaryOp<IntegerAttr>(
+ {lhs, rhs}, [](const APInt &a, const APInt &b) { return a * b; });
+
+ if (!lowBits)
+ return failure();
+
+ auto highBits = constFoldBinaryOp<IntegerAttr>(
+ {lhs, rhs}, [](const APInt &a, const APInt &b) {
+ unsigned bitWidth = a.getBitWidth();
+ APInt c;
+ if (IsSigned) {
+ c = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
+ } else {
+ c = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
+ }
+ return c.extractBits(bitWidth, bitWidth); // Extract high result
+ });
+
+ if (!highBits)
+ return failure();
+
+ Value lowBitsVal =
+ rewriter.create<spirv::ConstantOp>(loc, constituentType, lowBits);
+ constituents.push_back(lowBitsVal);
+
+ Value highBitsVal =
+ rewriter.create<spirv::ConstantOp>(loc, constituentType, highBits);
+ constituents.push_back(highBitsVal);
+
+ rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+ constituents);
+ return success();
+ }
+};
+
+using SMulExtendedOpFold = MulExtendedFold<spirv::SMulExtendedOp, true>;
+void spirv::SMulExtendedOp::getCanonicalizationPatterns(
+ RewritePatternSet &patterns, MLIRContext *context) {
+ patterns.add<SMulExtendedOpFold>(context);
+}
+
+struct UMulExtendedOpXOne final : OpRewritePattern<spirv::UMulExtendedOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(spirv::UMulExtendedOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ auto operands = op.getOperands();
+
+ SmallVector<Value> constituents;
+ Type constituentType = operands[0].getType();
+
+ // umulextended (x, 1) = <x, 0>
+ if (matchPattern(operands[1], m_One())) {
+ Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
+ constituents.push_back(operands[0]);
+ constituents.push_back(zero);
+ rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+ constituents);
+ return success();
+ }
+
+ return failure();
+ }
+};
+
+using UMulExtendedOpFold = MulExtendedFold<spirv::UMulExtendedOp, false>;
+void spirv::UMulExtendedOp::getCanonicalizationPatterns(
+ RewritePatternSet &patterns, MLIRContext *context) {
+ patterns.add<UMulExtendedOpFold, UMulExtendedOpXOne>(context);
+}
+
//===----------------------------------------------------------------------===//
// spirv.UMod
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 0200805a444397a..16215e21b369584 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -336,6 +336,52 @@ func.func @iadd_poison(%arg0: i32) -> i32 {
// -----
+//===----------------------------------------------------------------------===//
+// spirv.IAddCarry
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @iaddcarry_x_0
+func.func @iaddcarry_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+ %c0 = spirv.Constant 0 : i32
+
+ // CHECK: spirv.CompositeConstruct
+ %0 = spirv.IAddCarry %arg0, %c0 : !spirv.struct<(i32, i32)>
+ return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_scalar_iaddcarry
+func.func @const_fold_scalar_iaddcarry() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
+ %c5 = spirv.Constant 5 : i32
+ %cn5 = spirv.Constant -5 : i32
+ %cn8 = spirv.Constant -8 : i32
+
+ // CHECK-DAG: spirv.Constant 0
+ // CHECK-DAG: spirv.Constant -3
+ // CHECK-DAG: spirv.CompositeConstruct
+ // CHECK-DAG: spirv.Constant 1
+ // CHECK-DAG: spirv.Constant -13
+ // CHECK-DAG: spirv.CompositeConstruct
+ %0 = spirv.IAddCarry %c5, %cn8 : !spirv.struct<(i32, i32)>
+ %1 = spirv.IAddCarry %cn5, %cn8 : !spirv.struct<(i32, i32)>
+
+ return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_vector_iaddcarry
+func.func @const_fold_vector_iaddcarry() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
+ %v0 = spirv.Constant dense<[5, -3, -1]> : vector<3xi32>
+ %v1 = spirv.Constant dense<[-8, -8, 1]> : vector<3xi32>
+
+ // CHECK-DAG: spirv.Constant dense<[0, 1, 1]>
+ // CHECK-DAG: spirv.Constant dense<[-3, -11, 0]>
+ // CHECK-DAG: spirv.CompositeConstruct
+ %0 = spirv.IAddCarry %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+ return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.IMul
//===----------------------------------------------------------------------===//
@@ -400,6 +446,108 @@ func.func @const_fold_vector_imul() -> vector<3xi32> {
// -----
+//===----------------------------------------------------------------------===//
+// spirv.SMulExtended
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @smulextended_x_0
+func.func @smulextended_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+ %c0 = spirv.Constant 0 : i32
+
+ // CHECK: spirv.CompositeConstruct
+ %0 = spirv.SMulExtended %arg0, %c0 : !spirv.struct<(i32, i32)>
+ return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_scalar_smulextended
+func.func @const_fold_scalar_smulextended() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
+ %c5 = spirv.Constant 5 : i32
+ %cn5 = spirv.Constant -5 : i32
+ %cn8 = spirv.Constant -8 : i32
+
+ // CHECK-DAG: spirv.Constant -40
+ // CHECK-DAG: spirv.Constant -1
+ // CHECK-DAG: spirv.CompositeConstruct
+ // CHECK-DAG: spirv.Constant 40
+ // CHECK-DAG: spirv.Constant 0
+ // CHECK-DAG: spirv.CompositeConstruct
+ %0 = spirv.SMulExtended %c5, %cn8 : !spirv.struct<(i32, i32)>
+ %1 = spirv.SMulExtended %cn5, %cn8 : !spirv.struct<(i32, i32)>
+
+ return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_vector_smulextended
+func.func @const_fold_vector_smulextended() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
+ %v0 = spirv.Constant dense<[2147483647, -5, -1]> : vector<3xi32>
+ %v1 = spirv.Constant dense<[5, -8, 1]> : vector<3xi32>
+
+ // CHECK: spirv.Constant dense<[2147483643, 40, -1]>
+ // CHECK-NEXT: spirv.Constant dense<[2, 0, -1]>
+ // CHECK-NEXT: spirv.CompositeConstruct
+ %0 = spirv.SMulExtended %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+ return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.UMulExtended
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @umulextended_x_0
+func.func @umulextended_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+ %c0 = spirv.Constant 0 : i32
+
+ // CHECK: spirv.CompositeConstruct
+ %0 = spirv.UMulExtended %arg0, %c0 : !spirv.struct<(i32, i32)>
+ return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @umulextended_x_1
+func.func @umulextended_x_1(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+ %c0 = spirv.Constant 1 : i32
+
+ // CHECK: spirv.CompositeConstruct
+ %0 = spirv.UMulExtended %arg0, %c0 : !spirv.struct<(i32, i32)>
+ return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_scalar_umulextended
+func.func @const_fold_scalar_umulextended() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
+ %c5 = spirv.Constant 5 : i32
+ %cn5 = spirv.Constant -5 : i32
+ %cn8 = spirv.Constant -8 : i32
+
+ // CHECK-DAG: spirv.Constant 40
+ // CHECK-DAG: spirv.Constant -13
+ // CHECK-DAG: spirv.CompositeConstruct
+ // CHECK-DAG: spirv.Constant -40
+ // CHECK-DAG: spirv.Constant 4
+ // CHECK-DAG: spirv.CompositeConstruct
+ %0 = spirv.UMulExtended %c5, %cn8 : !spirv.struct<(i32, i32)>
+ %1 = spirv.UMulExtended %cn5, %cn8 : !spirv.struct<(i32, i32)>
+
+ return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_vector_umulextended
+func.func @const_fold_vector_umulextended() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
+ %v0 = spirv.Constant dense<[2147483647, -5, -1]> : vector<3xi32>
+ %v1 = spirv.Constant dense<[5, -8, 1]> : vector<3xi32>
+
+ // CHECK: spirv.Constant dense<[2147483643, 40, -1]>
+ // CHECK-NEXT: spirv.Constant dense<[2, -13, 0]>
+ // CHECK-NEXT: spirv.CompositeConstruct
+ %0 = spirv.UMulExtended %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+ return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+
+}
+
+// -----
+
+
//===----------------------------------------------------------------------===//
// spirv.ISub
//===----------------------------------------------------------------------===//
|
LogicalResult matchAndRewrite(spirv::IAddCarryOp op, | ||
PatternRewriter &rewriter) const override { | ||
Location loc = op.getLoc(); | ||
auto operands = op.getOperands(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can do:
Value lhs = op.getOperand1();
Value rhs = op.getOperand2();
If this is not a very intuitive naming, I can add extra getters op.getLhs()
and op.getRhs()
-- WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree this looks better. I think it is okay to be done locally rather than having the getters.
if (!adds) | ||
return failure(); | ||
|
||
auto carrys = constFoldBinaryOp<IntegerAttr>( | ||
ArrayRef{adds, lhs}, [](const APInt &a, const APInt &b) { | ||
APInt zero = APInt::getZero(a.getBitWidth()); | ||
return a.ult(b) ? (zero + 1) : zero; | ||
}); | ||
|
||
if (!carrys) | ||
return failure(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that instead of using CompositeConstruct
we could also create two insertions into the struct -- one for the sum, a separate one for the carry. This way we don't have to be able to fold both to simplify the result.
For example when we see something like this:
%x = spirv.UConvert %a : i8 to i32
%y = spirv.UConvert %b : i8 to i32
%res = spirv.IAddCarry %a, %b : i32
Here we can statically tell that the overflow bit is 0
, but can't necessarily decide that the sum component is.
This may be a bit obscure, but maybe we could add it as a TODO comment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree that if it can allow for more greedy folding then we should use it. We also have the -spirv-rewrite-inserts
pass that can convert from a CompositeInsert
chain to the CompositeConstruct
in the cases when desired.
The most recent commit has a potential implementation of switching to this.
- improve readability with lhs/rhs instead of operands[0]/[1] - use stack array instead of llvm::SmallVector - increase strictness of tests to ensure proper CompositeConstruct and return order
✅ With the latest revision this PR passed the C/C++ code formatter. |
- commit to demonstrate how we could potentially use CompositeInsert instead of CompositeConstruct
6a51740
to
a0cb359
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
BTW @inbelic, with this many PRs in you should be able to request write access if you haven't done so already. The instructions are here: https://llvm.org/docs/DeveloperPolicy.html#obtaining-commit-access |
Add missing constant propogation folder for IAddCarry and [S|U]MulExtended. Due to currently missing constant value for spirv.struct the folding is done using canonicalization patterns.
Implement additional folding when rhs is 0 for all ops and when rhs is 1 for UMulExt.
This helps for readability of lowered code into SPIR-V.
Part of work for #70704