Skip to content

Commit 8cbb8ac

Browse files
authored
[mlir][spirv] Add folding for SelectOp (#85430)
Add missing constant propogation folder for spirv.Select Implement additional folding when both selections are equivalent or the condition is a constant Scalar/SplatVector. Allows for constant folding in the IndexToSPIRV pass. Part of work #70704
1 parent 2377b97 commit 8cbb8ac

File tree

4 files changed

+96
-5
lines changed

4 files changed

+96
-5
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -800,6 +800,8 @@ def SPIRV_SelectOp : SPIRV_Op<"Select",
800800
// These ops require dynamic availability specification based on operand and
801801
// result types.
802802
bit autogenAvailability = 0;
803+
804+
let hasFolder = 1;
803805
}
804806

805807
// -----

mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -797,6 +797,49 @@ OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
797797
return Attribute();
798798
}
799799

800+
//===----------------------------------------------------------------------===//
801+
// spirv.SelectOp
802+
//===----------------------------------------------------------------------===//
803+
804+
OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
805+
// spirv.Select _ x x -> x
806+
Value trueVals = getTrueValue();
807+
Value falseVals = getFalseValue();
808+
if (trueVals == falseVals)
809+
return trueVals;
810+
811+
ArrayRef<Attribute> operands = adaptor.getOperands();
812+
813+
// spirv.Select true x y -> x
814+
// spirv.Select false x y -> y
815+
if (auto boolAttr = getScalarOrSplatBoolAttr(operands[0]))
816+
return *boolAttr ? trueVals : falseVals;
817+
818+
// Check that all the operands are constant
819+
if (!operands[0] || !operands[1] || !operands[2])
820+
return Attribute();
821+
822+
// Note: getScalarOrSplatBoolAttr will always return a boolAttr if we are in
823+
// the scalar case. Hence, we are only required to consider the case of
824+
// DenseElementsAttr in foldSelectOp.
825+
auto condAttrs = dyn_cast<DenseElementsAttr>(operands[0]);
826+
auto trueAttrs = dyn_cast<DenseElementsAttr>(operands[1]);
827+
auto falseAttrs = dyn_cast<DenseElementsAttr>(operands[2]);
828+
if (!condAttrs || !trueAttrs || !falseAttrs)
829+
return Attribute();
830+
831+
auto elementResults = llvm::to_vector<4>(trueAttrs.getValues<Attribute>());
832+
auto iters = llvm::zip_equal(elementResults, condAttrs.getValues<BoolAttr>(),
833+
falseAttrs.getValues<Attribute>());
834+
for (auto [result, cond, falseRes] : iters) {
835+
if (!cond.getValue())
836+
result = falseRes;
837+
}
838+
839+
auto resultType = trueAttrs.getType();
840+
return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
841+
}
842+
800843
//===----------------------------------------------------------------------===//
801844
// spirv.IEqualOp
802845
//===----------------------------------------------------------------------===//

mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,18 @@ spirv.func @composite_insert_vector(%arg0: vector<3xf32>, %arg1: f32) "None" {
4343
//===----------------------------------------------------------------------===//
4444

4545
// CHECK-LABEL: @select_scalar
46-
spirv.func @select_scalar(%arg0: i1, %arg1: vector<3xi32>, %arg2: f32) "None" {
46+
spirv.func @select_scalar(%arg0: i1, %arg1: vector<3xi32>, %arg2: vector<3xi32>, %arg3: f32, %arg4: f32) "None" {
4747
// CHECK: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : i1, vector<3xi32>
48-
%0 = spirv.Select %arg0, %arg1, %arg1 : i1, vector<3xi32>
48+
%0 = spirv.Select %arg0, %arg1, %arg2 : i1, vector<3xi32>
4949
// CHECK: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : i1, f32
50-
%1 = spirv.Select %arg0, %arg2, %arg2 : i1, f32
50+
%1 = spirv.Select %arg0, %arg3, %arg4 : i1, f32
5151
spirv.Return
5252
}
5353

5454
// CHECK-LABEL: @select_vector
55-
spirv.func @select_vector(%arg0: vector<2xi1>, %arg1: vector<2xi32>) "None" {
55+
spirv.func @select_vector(%arg0: vector<2xi1>, %arg1: vector<2xi32>, %arg2: vector<2xi32>) "None" {
5656
// CHECK: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : vector<2xi1>, vector<2xi32>
57-
%0 = spirv.Select %arg0, %arg1, %arg1 : vector<2xi1>, vector<2xi32>
57+
%0 = spirv.Select %arg0, %arg1, %arg2 : vector<2xi1>, vector<2xi32>
5858
spirv.Return
5959
}
6060

mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1346,6 +1346,52 @@ func.func @convert_logical_or_true_false_vector(%arg: vector<3xi1>) -> (vector<3
13461346

13471347
// -----
13481348

1349+
//===----------------------------------------------------------------------===//
1350+
// spirv.Select
1351+
//===----------------------------------------------------------------------===//
1352+
1353+
// CHECK-LABEL: @convert_select_scalar
1354+
// CHECK-SAME: %[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32
1355+
func.func @convert_select_scalar(%arg1: i32, %arg2: i32) -> (i32, i32) {
1356+
%true = spirv.Constant true
1357+
%false = spirv.Constant false
1358+
%0 = spirv.Select %true, %arg1, %arg2 : i1, i32
1359+
%1 = spirv.Select %false, %arg1, %arg2 : i1, i32
1360+
1361+
// CHECK: return %[[ARG1]], %[[ARG2]]
1362+
return %0, %1 : i32, i32
1363+
}
1364+
1365+
// CHECK-LABEL: @convert_select_vector
1366+
// CHECK-SAME: %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<3xi32>
1367+
func.func @convert_select_vector(%arg1: vector<3xi32>, %arg2: vector<3xi32>) -> (vector<3xi32>, vector<3xi32>) {
1368+
%true = spirv.Constant dense<true> : vector<3xi1>
1369+
%false = spirv.Constant dense<false> : vector<3xi1>
1370+
%0 = spirv.Select %true, %arg1, %arg2 : vector<3xi1>, vector<3xi32>
1371+
%1 = spirv.Select %false, %arg1, %arg2 : vector<3xi1>, vector<3xi32>
1372+
1373+
// CHECK: return %[[ARG1]], %[[ARG2]]
1374+
return %0, %1: vector<3xi32>, vector<3xi32>
1375+
}
1376+
1377+
// CHECK-LABEL: @convert_select_vector_extra
1378+
// CHECK-SAME: %[[CONDITIONS:.+]]: vector<2xi1>, %[[ARG1:.+]]: vector<2xi32>
1379+
func.func @convert_select_vector_extra(%conditions: vector<2xi1>, %arg1: vector<2xi32>) -> (vector<2xi32>, vector<2xi32>) {
1380+
%true_false = spirv.Constant dense<[true, false]> : vector<2xi1>
1381+
%cvec_1 = spirv.Constant dense<[42, -132]> : vector<2xi32>
1382+
%cvec_2 = spirv.Constant dense<[0, 42]> : vector<2xi32>
1383+
1384+
// CHECK: %[[RES:.+]] = spirv.Constant dense<42>
1385+
%0 = spirv.Select %true_false, %cvec_1, %cvec_2: vector<2xi1>, vector<2xi32>
1386+
1387+
%1 = spirv.Select %conditions, %arg1, %arg1 : vector<2xi1>, vector<2xi32>
1388+
1389+
// CHECK: return %[[RES]], %[[ARG1]]
1390+
return %0, %1: vector<2xi32>, vector<2xi32>
1391+
}
1392+
1393+
// -----
1394+
13491395
//===----------------------------------------------------------------------===//
13501396
// spirv.IEqual
13511397
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)