Skip to content

Commit cc3b6f9

Browse files
authored
[mlir][spirv] Add folding for [S|U|GreaterThan[Equal] (#85434)
Add missing constant propogation folder for [S|U]GreaterThan[Equal]. Implement additional folding when the operands are equal for all ops. Allows for constant folding in the IndexToSPIRV pass. Part of work #70704
1 parent 93f9fb2 commit cc3b6f9

File tree

3 files changed

+268
-0
lines changed

3 files changed

+268
-0
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,8 @@ def SPIRV_SGreaterThanOp : SPIRV_LogicalBinaryOp<"SGreaterThan",
659659

660660
```
661661
}];
662+
663+
let hasFolder = 1;
662664
}
663665

664666
// -----
@@ -688,6 +690,8 @@ def SPIRV_SGreaterThanEqualOp : SPIRV_LogicalBinaryOp<"SGreaterThanEqual",
688690
%5 = spirv.SGreaterThanEqual %2, %3 : vector<4xi32>
689691
```
690692
}];
693+
694+
let hasFolder = 1;
691695
}
692696

693697
// -----
@@ -834,6 +838,8 @@ def SPIRV_UGreaterThanOp : SPIRV_LogicalBinaryOp<"UGreaterThan",
834838
%5 = spirv.UGreaterThan %2, %3 : vector<4xi32>
835839
```
836840
}];
841+
842+
let hasFolder = 1;
837843
}
838844

839845
// -----
@@ -863,6 +869,8 @@ def SPIRV_UGreaterThanEqualOp : SPIRV_LogicalBinaryOp<"UGreaterThanEqual",
863869
%5 = spirv.UGreaterThanEqual %2, %3 : vector<4xi32>
864870
```
865871
}];
872+
873+
let hasFolder = 1;
866874
}
867875

868876
// -----

mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -880,6 +880,90 @@ OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
880880
});
881881
}
882882

883+
//===----------------------------------------------------------------------===//
884+
// spirv.SGreaterThan
885+
//===----------------------------------------------------------------------===//
886+
887+
OpFoldResult
888+
spirv::SGreaterThanOp::fold(spirv::SGreaterThanOp::FoldAdaptor adaptor) {
889+
// x == x -> false
890+
if (getOperand1() == getOperand2()) {
891+
auto falseAttr = BoolAttr::get(getContext(), false);
892+
if (isa<IntegerType>(getType()))
893+
return falseAttr;
894+
if (auto vecTy = dyn_cast<VectorType>(getType()))
895+
return SplatElementsAttr::get(vecTy, falseAttr);
896+
}
897+
898+
return constFoldBinaryOp<IntegerAttr>(
899+
adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
900+
return a.sgt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
901+
});
902+
}
903+
904+
//===----------------------------------------------------------------------===//
905+
// spirv.SGreaterThanEqual
906+
//===----------------------------------------------------------------------===//
907+
908+
OpFoldResult spirv::SGreaterThanEqualOp::fold(
909+
spirv::SGreaterThanEqualOp::FoldAdaptor adaptor) {
910+
// x == x -> true
911+
if (getOperand1() == getOperand2()) {
912+
auto trueAttr = BoolAttr::get(getContext(), true);
913+
if (isa<IntegerType>(getType()))
914+
return trueAttr;
915+
if (auto vecTy = dyn_cast<VectorType>(getType()))
916+
return SplatElementsAttr::get(vecTy, trueAttr);
917+
}
918+
919+
return constFoldBinaryOp<IntegerAttr>(
920+
adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
921+
return a.sge(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
922+
});
923+
}
924+
925+
//===----------------------------------------------------------------------===//
926+
// spirv.UGreaterThan
927+
//===----------------------------------------------------------------------===//
928+
929+
OpFoldResult
930+
spirv::UGreaterThanOp::fold(spirv::UGreaterThanOp::FoldAdaptor adaptor) {
931+
// x == x -> false
932+
if (getOperand1() == getOperand2()) {
933+
auto falseAttr = BoolAttr::get(getContext(), false);
934+
if (isa<IntegerType>(getType()))
935+
return falseAttr;
936+
if (auto vecTy = dyn_cast<VectorType>(getType()))
937+
return SplatElementsAttr::get(vecTy, falseAttr);
938+
}
939+
940+
return constFoldBinaryOp<IntegerAttr>(
941+
adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
942+
return a.ugt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
943+
});
944+
}
945+
946+
//===----------------------------------------------------------------------===//
947+
// spirv.UGreaterThanEqual
948+
//===----------------------------------------------------------------------===//
949+
950+
OpFoldResult spirv::UGreaterThanEqualOp::fold(
951+
spirv::UGreaterThanEqualOp::FoldAdaptor adaptor) {
952+
// x == x -> true
953+
if (getOperand1() == getOperand2()) {
954+
auto trueAttr = BoolAttr::get(getContext(), true);
955+
if (isa<IntegerType>(getType()))
956+
return trueAttr;
957+
if (auto vecTy = dyn_cast<VectorType>(getType()))
958+
return SplatElementsAttr::get(vecTy, trueAttr);
959+
}
960+
961+
return constFoldBinaryOp<IntegerAttr>(
962+
adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
963+
return a.uge(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
964+
});
965+
}
966+
883967
//===----------------------------------------------------------------------===//
884968
// spirv.SLessThan
885969
//===----------------------------------------------------------------------===//

mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1478,6 +1478,182 @@ func.func @const_fold_vector_inotequal() -> vector<3xi1> {
14781478

14791479
// -----
14801480

1481+
//===----------------------------------------------------------------------===//
1482+
// spirv.SGreaterThan
1483+
//===----------------------------------------------------------------------===//
1484+
1485+
// CHECK-LABEL: @sgt_same
1486+
func.func @sgt_same(%arg0 : i32, %arg1 : vector<3xi32>) -> (i1, vector<3xi1>) {
1487+
// CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
1488+
// CHECK-DAG: %[[CVFALSE:.*]] = spirv.Constant dense<false>
1489+
%0 = spirv.SGreaterThan %arg0, %arg0 : i32
1490+
%1 = spirv.SGreaterThan %arg1, %arg1 : vector<3xi32>
1491+
1492+
// CHECK: return %[[CFALSE]], %[[CVFALSE]]
1493+
return %0, %1 : i1, vector<3xi1>
1494+
}
1495+
1496+
// CHECK-LABEL: @const_fold_scalar_sgt
1497+
func.func @const_fold_scalar_sgt() -> (i1, i1) {
1498+
%c4 = spirv.Constant 4 : i32
1499+
%c5 = spirv.Constant 5 : i32
1500+
%c6 = spirv.Constant 6 : i32
1501+
1502+
// CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
1503+
// CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
1504+
%0 = spirv.SGreaterThan %c5, %c6 : i32
1505+
%1 = spirv.SGreaterThan %c5, %c4 : i32
1506+
1507+
// CHECK: return %[[CFALSE]], %[[CTRUE]]
1508+
return %0, %1 : i1, i1
1509+
}
1510+
1511+
// CHECK-LABEL: @const_fold_vector_sgt
1512+
func.func @const_fold_vector_sgt() -> vector<3xi1> {
1513+
%cv0 = spirv.Constant dense<[-1, -4, 3]> : vector<3xi32>
1514+
%cv1 = spirv.Constant dense<[-1, -3, 2]> : vector<3xi32>
1515+
1516+
// CHECK: %[[RET:.*]] = spirv.Constant dense<[false, false, true]>
1517+
%0 = spirv.SGreaterThan %cv0, %cv1 : vector<3xi32>
1518+
1519+
// CHECK: return %[[RET]]
1520+
return %0 : vector<3xi1>
1521+
}
1522+
1523+
// -----
1524+
1525+
//===----------------------------------------------------------------------===//
1526+
// spirv.SGreaterThanEqual
1527+
//===----------------------------------------------------------------------===//
1528+
1529+
// CHECK-LABEL: @sge_same
1530+
func.func @sge_same(%arg0 : i32, %arg1 : vector<3xi32>) -> (i1, vector<3xi1>) {
1531+
// CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
1532+
// CHECK-DAG: %[[CVTRUE:.*]] = spirv.Constant dense<true>
1533+
%0 = spirv.SGreaterThanEqual %arg0, %arg0 : i32
1534+
%1 = spirv.SGreaterThanEqual %arg1, %arg1 : vector<3xi32>
1535+
1536+
// CHECK: return %[[CTRUE]], %[[CVTRUE]]
1537+
return %0, %1 : i1, vector<3xi1>
1538+
}
1539+
1540+
// CHECK-LABEL: @const_fold_scalar_sge
1541+
func.func @const_fold_scalar_sge() -> (i1, i1) {
1542+
%c4 = spirv.Constant 4 : i32
1543+
%c5 = spirv.Constant 5 : i32
1544+
%c6 = spirv.Constant 6 : i32
1545+
1546+
// CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
1547+
// CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
1548+
%0 = spirv.SGreaterThanEqual %c5, %c6 : i32
1549+
%1 = spirv.SGreaterThanEqual %c5, %c4 : i32
1550+
1551+
// CHECK: return %[[CFALSE]], %[[CTRUE]]
1552+
return %0, %1 : i1, i1
1553+
}
1554+
1555+
// CHECK-LABEL: @const_fold_vector_sge
1556+
func.func @const_fold_vector_sge() -> vector<3xi1> {
1557+
%cv0 = spirv.Constant dense<[-1, -4, 3]> : vector<3xi32>
1558+
%cv1 = spirv.Constant dense<[-1, -3, 2]> : vector<3xi32>
1559+
1560+
// CHECK: %[[RET:.*]] = spirv.Constant dense<[true, false, true]>
1561+
%0 = spirv.SGreaterThanEqual %cv0, %cv1 : vector<3xi32>
1562+
1563+
// CHECK: return %[[RET]]
1564+
return %0 : vector<3xi1>
1565+
}
1566+
1567+
// -----
1568+
1569+
//===----------------------------------------------------------------------===//
1570+
// spirv.UGreaterThan
1571+
//===----------------------------------------------------------------------===//
1572+
1573+
// CHECK-LABEL: @ugt_same
1574+
func.func @ugt_same(%arg0 : i32, %arg1 : vector<3xi32>) -> (i1, vector<3xi1>) {
1575+
// CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
1576+
// CHECK-DAG: %[[CVFALSE:.*]] = spirv.Constant dense<false>
1577+
%0 = spirv.UGreaterThan %arg0, %arg0 : i32
1578+
%1 = spirv.UGreaterThan %arg1, %arg1 : vector<3xi32>
1579+
1580+
// CHECK: return %[[CFALSE]], %[[CVFALSE]]
1581+
return %0, %1 : i1, vector<3xi1>
1582+
}
1583+
1584+
// CHECK-LABEL: @const_fold_scalar_ugt
1585+
func.func @const_fold_scalar_ugt() -> (i1, i1) {
1586+
%c4 = spirv.Constant 4 : i32
1587+
%c5 = spirv.Constant 5 : i32
1588+
%cn6 = spirv.Constant -6 : i32
1589+
1590+
// CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
1591+
// CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
1592+
%0 = spirv.UGreaterThan %c5, %cn6 : i32
1593+
%1 = spirv.UGreaterThan %c5, %c4 : i32
1594+
1595+
// CHECK: return %[[CFALSE]], %[[CTRUE]]
1596+
return %0, %1 : i1, i1
1597+
}
1598+
1599+
// CHECK-LABEL: @const_fold_vector_ugt
1600+
func.func @const_fold_vector_ugt() -> vector<3xi1> {
1601+
%cv0 = spirv.Constant dense<[-1, -4, 3]> : vector<3xi32>
1602+
%cv1 = spirv.Constant dense<[-1, -3, 2]> : vector<3xi32>
1603+
1604+
// CHECK: %[[RET:.*]] = spirv.Constant dense<[false, false, true]>
1605+
%0 = spirv.UGreaterThan %cv0, %cv1 : vector<3xi32>
1606+
1607+
// CHECK: return %[[RET]]
1608+
return %0 : vector<3xi1>
1609+
}
1610+
1611+
// -----
1612+
1613+
//===----------------------------------------------------------------------===//
1614+
// spirv.UGreaterThanEqual
1615+
//===----------------------------------------------------------------------===//
1616+
1617+
// CHECK-LABEL: @uge_same
1618+
func.func @uge_same(%arg0 : i32, %arg1 : vector<3xi32>) -> (i1, vector<3xi1>) {
1619+
// CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
1620+
// CHECK-DAG: %[[CVTRUE:.*]] = spirv.Constant dense<true>
1621+
%0 = spirv.UGreaterThanEqual %arg0, %arg0 : i32
1622+
%1 = spirv.UGreaterThanEqual %arg1, %arg1 : vector<3xi32>
1623+
1624+
// CHECK: return %[[CTRUE]], %[[CVTRUE]]
1625+
return %0, %1 : i1, vector<3xi1>
1626+
}
1627+
1628+
// CHECK-LABEL: @const_fold_scalar_uge
1629+
func.func @const_fold_scalar_uge() -> (i1, i1) {
1630+
%c4 = spirv.Constant 4 : i32
1631+
%c5 = spirv.Constant 5 : i32
1632+
%cn6 = spirv.Constant -6 : i32
1633+
1634+
// CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
1635+
// CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
1636+
%0 = spirv.UGreaterThanEqual %c5, %cn6 : i32
1637+
%1 = spirv.UGreaterThanEqual %c5, %c4 : i32
1638+
1639+
// CHECK: return %[[CFALSE]], %[[CTRUE]]
1640+
return %0, %1 : i1, i1
1641+
}
1642+
1643+
// CHECK-LABEL: @const_fold_vector_uge
1644+
func.func @const_fold_vector_uge() -> vector<3xi1> {
1645+
%cv0 = spirv.Constant dense<[-1, -4, 3]> : vector<3xi32>
1646+
%cv1 = spirv.Constant dense<[-1, -3, 2]> : vector<3xi32>
1647+
1648+
// CHECK: %[[RET:.*]] = spirv.Constant dense<[true, false, true]>
1649+
%0 = spirv.UGreaterThanEqual %cv0, %cv1 : vector<3xi32>
1650+
1651+
// CHECK: return %[[RET]]
1652+
return %0 : vector<3xi1>
1653+
}
1654+
1655+
// -----
1656+
14811657
//===----------------------------------------------------------------------===//
14821658
// spirv.SLessThan
14831659
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)