Skip to content

Commit a3c7d46

Browse files
authored
[mlir][spirv] Implement UMod canonicalization for vector constants (#141902)
Closes #63174. Implements this transformation pattern, which is currently only applied to scalars, for vectors: ``` %1 = "spirv.UMod"(%0, %CONST_32) : (i32, i32) -> i32 %2 = "spirv.UMod"(%1, %CONST_4) : (i32, i32) -> i32 ``` to ``` %1 = "spirv.UMod"(%0, %CONST_32) : (i32, i32) -> i32 %2 = "spirv.UMod"(%0, %CONST_4) : (i32, i32) -> i32 ``` Additionally fixes and issue where patterns like this: ``` %1 = "spirv.UMod"(%0, %CONST_4) : (i32, i32) -> i32 %2 = "spirv.UMod"(%1, %CONST_32) : (i32, i32) -> i32 ``` were incorrectly canonicalized to: ``` %1 = "spirv.UMod"(%0, %CONST_4) : (i32, i32) -> i32 %2 = "spirv.UMod"(%0, %CONST_32) : (i32, i32) -> i32 ``` which is incorrect since `(X % A) % B` == `(X % B)` IFF A is a multiple of B, i.e., B divides A.
1 parent 02f0f5c commit a3c7d46

File tree

2 files changed

+69
-15
lines changed

2 files changed

+69
-15
lines changed

mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,6 @@ void spirv::UMulExtendedOp::getCanonicalizationPatterns(
326326

327327
// The transformation is only applied if one divisor is a multiple of the other.
328328

329-
// TODO(https://github.com/llvm/llvm-project/issues/63174): Add support for vector constants
330329
struct UModSimplification final : OpRewritePattern<spirv::UModOp> {
331330
using OpRewritePattern::OpRewritePattern;
332331

@@ -336,19 +335,29 @@ struct UModSimplification final : OpRewritePattern<spirv::UModOp> {
336335
if (!prevUMod)
337336
return failure();
338337

339-
IntegerAttr prevValue;
340-
IntegerAttr currValue;
338+
TypedAttr prevValue;
339+
TypedAttr currValue;
341340
if (!matchPattern(prevUMod.getOperand(1), m_Constant(&prevValue)) ||
342341
!matchPattern(umodOp.getOperand(1), m_Constant(&currValue)))
343342
return failure();
344343

345-
APInt prevConstValue = prevValue.getValue();
346-
APInt currConstValue = currValue.getValue();
344+
// Ensure that previous divisor is a multiple of the current divisor. If
345+
// not, fail the transformation.
346+
bool isApplicable = false;
347+
if (auto prevInt = dyn_cast<IntegerAttr>(prevValue)) {
348+
auto currInt = cast<IntegerAttr>(currValue);
349+
isApplicable = prevInt.getValue().urem(currInt.getValue()) == 0;
350+
} else if (auto prevVec = dyn_cast<DenseElementsAttr>(prevValue)) {
351+
auto currVec = cast<DenseElementsAttr>(currValue);
352+
isApplicable = llvm::all_of(llvm::zip_equal(prevVec.getValues<APInt>(),
353+
currVec.getValues<APInt>()),
354+
[](const auto &pair) {
355+
auto &[prev, curr] = pair;
356+
return prev.urem(curr) == 0;
357+
});
358+
}
347359

348-
// Ensure that one divisor is a multiple of the other. If not, fail the
349-
// transformation.
350-
if (prevConstValue.urem(currConstValue) != 0 &&
351-
currConstValue.urem(prevConstValue) != 0)
360+
if (!isApplicable)
352361
return failure();
353362

354363
// The transformation is safe. Replace the existing UMod operation with a

mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -967,17 +967,17 @@ func.func @umod_fold(%arg0: i32) -> (i32, i32) {
967967
return %0, %1: i32, i32
968968
}
969969

970-
// CHECK-LABEL: @umod_fail_vector_fold
970+
// CHECK-LABEL: @umod_vector_fold
971971
// CHECK-SAME: (%[[ARG:.*]]: vector<4xi32>)
972-
func.func @umod_fail_vector_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
972+
func.func @umod_vector_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
973973
// CHECK: %[[CONST4:.*]] = spirv.Constant dense<4> : vector<4xi32>
974974
// CHECK: %[[CONST32:.*]] = spirv.Constant dense<32> : vector<4xi32>
975975
%const1 = spirv.Constant dense<32> : vector<4xi32>
976976
%0 = spirv.UMod %arg0, %const1 : vector<4xi32>
977-
// CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST32]]
978977
%const2 = spirv.Constant dense<4> : vector<4xi32>
979978
%1 = spirv.UMod %0, %const2 : vector<4xi32>
980-
// CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST4]]
979+
// CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST32]]
980+
// CHECK: %[[UMOD1:.*]] = spirv.UMod %[[ARG]], %[[CONST4]]
981981
// CHECK: return %[[UMOD0]], %[[UMOD1]]
982982
return %0, %1: vector<4xi32>, vector<4xi32>
983983
}
@@ -996,9 +996,9 @@ func.func @umod_fold_same_divisor(%arg0: i32) -> (i32, i32) {
996996
return %0, %1: i32, i32
997997
}
998998

999-
// CHECK-LABEL: @umod_fail_fold
999+
// CHECK-LABEL: @umod_fail_1_fold
10001000
// CHECK-SAME: (%[[ARG:.*]]: i32)
1001-
func.func @umod_fail_fold(%arg0: i32) -> (i32, i32) {
1001+
func.func @umod_fail_1_fold(%arg0: i32) -> (i32, i32) {
10021002
// CHECK: %[[CONST5:.*]] = spirv.Constant 5
10031003
// CHECK: %[[CONST32:.*]] = spirv.Constant 32
10041004
%const1 = spirv.Constant 32 : i32
@@ -1011,6 +1011,51 @@ func.func @umod_fail_fold(%arg0: i32) -> (i32, i32) {
10111011
return %0, %1: i32, i32
10121012
}
10131013

1014+
// CHECK-LABEL: @umod_fail_2_fold
1015+
// CHECK-SAME: (%[[ARG:.*]]: i32)
1016+
func.func @umod_fail_2_fold(%arg0: i32) -> (i32, i32) {
1017+
// CHECK: %[[CONST32:.*]] = spirv.Constant 32
1018+
// CHECK: %[[CONST4:.*]] = spirv.Constant 4
1019+
%const1 = spirv.Constant 4 : i32
1020+
%0 = spirv.UMod %arg0, %const1 : i32
1021+
// CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST4]]
1022+
%const2 = spirv.Constant 32 : i32
1023+
%1 = spirv.UMod %0, %const2 : i32
1024+
// CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST32]]
1025+
// CHECK: return %[[UMOD0]], %[[UMOD1]]
1026+
return %0, %1: i32, i32
1027+
}
1028+
1029+
// CHECK-LABEL: @umod_vector_fail_1_fold
1030+
// CHECK-SAME: (%[[ARG:.*]]: vector<4xi32>)
1031+
func.func @umod_vector_fail_1_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
1032+
// CHECK: %[[CONST9:.*]] = spirv.Constant dense<9> : vector<4xi32>
1033+
// CHECK: %[[CONST64:.*]] = spirv.Constant dense<64> : vector<4xi32>
1034+
%const1 = spirv.Constant dense<64> : vector<4xi32>
1035+
%0 = spirv.UMod %arg0, %const1 : vector<4xi32>
1036+
// CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST64]]
1037+
%const2 = spirv.Constant dense<9> : vector<4xi32>
1038+
%1 = spirv.UMod %0, %const2 : vector<4xi32>
1039+
// CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST9]]
1040+
// CHECK: return %[[UMOD0]], %[[UMOD1]]
1041+
return %0, %1: vector<4xi32>, vector<4xi32>
1042+
}
1043+
1044+
// CHECK-LABEL: @umod_vector_fail_2_fold
1045+
// CHECK-SAME: (%[[ARG:.*]]: vector<4xi32>)
1046+
func.func @umod_vector_fail_2_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
1047+
// CHECK: %[[CONST32:.*]] = spirv.Constant dense<32> : vector<4xi32>
1048+
// CHECK: %[[CONST4:.*]] = spirv.Constant dense<4> : vector<4xi32>
1049+
%const1 = spirv.Constant dense<4> : vector<4xi32>
1050+
%0 = spirv.UMod %arg0, %const1 : vector<4xi32>
1051+
// CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST4]]
1052+
%const2 = spirv.Constant dense<32> : vector<4xi32>
1053+
%1 = spirv.UMod %0, %const2 : vector<4xi32>
1054+
// CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST32]]
1055+
// CHECK: return %[[UMOD0]], %[[UMOD1]]
1056+
return %0, %1: vector<4xi32>, vector<4xi32>
1057+
}
1058+
10141059
// -----
10151060

10161061
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)