@@ -1306,7 +1306,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
1306
1306
1307
1307
def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
1308
1308
ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
1309
- "ForeachOp", "IterateOp"]>]> {
1309
+ "ForeachOp", "IterateOp", "CoIterateOp" ]>]> {
1310
1310
let summary = "Yield from sparse_tensor set-like operations";
1311
1311
let description = [{
1312
1312
Yields a value from within a `binary`, `unary`, `reduce`,
@@ -1629,14 +1629,14 @@ def IterateOp : SparseTensor_Op<"iterate",
1629
1629
1630
1630
let arguments = (ins AnySparseIterSpace:$iterSpace,
1631
1631
Variadic<AnyType>:$initArgs,
1632
- LevelSetAttr :$crdUsedLvls);
1632
+ I64BitSetAttr :$crdUsedLvls);
1633
1633
let results = (outs Variadic<AnyType>:$results);
1634
1634
let regions = (region SizedRegion<1>:$region);
1635
1635
1636
1636
let skipDefaultBuilders = 1;
1637
1637
let builders = [
1638
1638
OpBuilder<(ins "Value":$iterSpace, "ValueRange":$initArgs)>,
1639
- OpBuilder<(ins "Value":$iterSpace, "ValueRange":$initArgs, "LevelSet " :$crdUsedLvls)>
1639
+ OpBuilder<(ins "Value":$iterSpace, "ValueRange":$initArgs, "I64BitSet " :$crdUsedLvls)>
1640
1640
];
1641
1641
1642
1642
let extraClassDeclaration = [{
@@ -1669,6 +1669,127 @@ def IterateOp : SparseTensor_Op<"iterate",
1669
1669
let hasCustomAssemblyFormat = 1;
1670
1670
}
1671
1671
1672
+ def SparseTensor_CoIterateOp : SparseTensor_Op<"coiterate",
1673
+ [AttrSizedOperandSegments,
1674
+ SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">,
1675
+ RecursiveMemoryEffects]> {
1676
+ let summary = "Co-iterates over a set of sparse iteration spaces";
1677
+ let description = [{
1678
+ The `sparse_tensor.coiterate` operation represents a loop (nest) over
1679
+ a set of iteration spaces. The operation can have multiple regions,
1680
+ with each of them defining a case to compute a result at the current iterations.
1681
+ The case condition is defined solely based on the pattern of specified iterators.
1682
+ For example:
1683
+ ```mlir
1684
+ %ret = sparse_tensor.coiterate (%sp1, %sp2) at(%coord) iter_args(%arg = %init)
1685
+ : (!sparse_tensor.iter_space<#CSR, lvls = 0>,
1686
+ !sparse_tensor.iter_space<#COO, lvls = 0>)
1687
+ -> index
1688
+ case %it1, _ {
1689
+ // %coord is specifed in space %sp1 but *NOT* specified in space %sp2.
1690
+ }
1691
+ case %it1, %it2 {
1692
+ // %coord is specifed in *BOTH* spaces %sp1 and %sp2.
1693
+ }
1694
+ ```
1695
+
1696
+ `sparse_tensor.coiterate` can also operate on loop-carried variables.
1697
+ It returns the final value for each loop-carried variable after loop termination.
1698
+ The initial values of the variables are passed as additional SSA operands
1699
+ to the iterator SSA value and used coordinate SSA values.
1700
+ Each operation region has variadic arguments for specified (used), one argument
1701
+ for each loop-carried variable, representing the value of the variable
1702
+ at the current iteration, followed by a list of arguments for iterators.
1703
+ The body region must contain exactly one block that terminates with
1704
+ `sparse_tensor.yield`.
1705
+
1706
+ The results of an `sparse_tensor.coiterate` hold the final values after
1707
+ the last iteration. If the `sparse_tensor.coiterate` defines any values,
1708
+ a yield must be explicitly present in every region defined in the operation.
1709
+ The number and types of the `sparse_tensor.coiterate` results must match
1710
+ the initial values in the iter_args binding and the yield operands.
1711
+
1712
+
1713
+ A `sparse_tensor.coiterate` example that does elementwise addition between two
1714
+ sparse vectors.
1715
+
1716
+
1717
+ ```mlir
1718
+ %ret = sparse_tensor.coiterate (%sp1, %sp2) at(%coord) iter_args(%arg = %init)
1719
+ : (!sparse_tensor.iter_space<#CSR, lvls = 0>,
1720
+ !sparse_tensor.iter_space<#CSR, lvls = 0>)
1721
+ -> tensor<?xindex, #CSR>
1722
+ case %it1, _ {
1723
+ // v = v1 + 0 = v1
1724
+ %v1 = sparse_tensor.extract_value %t1 at %it1 : index
1725
+ %yield = sparse_tensor.insert %v1 into %arg[%coord]
1726
+ sparse_tensor.yield %yield
1727
+ }
1728
+ case _, %it2 {
1729
+ // v = v2 + 0 = v2
1730
+ %v2 = sparse_tensor.extract_value %t2 at %it2 : index
1731
+ %yield = sparse_tensor.insert %v1 into %arg[%coord]
1732
+ sparse_tensor.yield %yield
1733
+ }
1734
+ case %it1, %it2 {
1735
+ // v = v1 + v2
1736
+ %v1 = sparse_tensor.extract_value %t1 at %it1 : index
1737
+ %v2 = sparse_tensor.extract_value %t2 at %it2 : index
1738
+ %v = arith.addi %v1, %v2 : index
1739
+ %yield = sparse_tensor.insert %v into %arg[%coord]
1740
+ sparse_tensor.yield %yield
1741
+ }
1742
+ ```
1743
+ }];
1744
+
1745
+ let arguments = (ins Variadic<AnySparseIterSpace>:$iterSpaces,
1746
+ Variadic<AnyType>:$initArgs,
1747
+ I64BitSetAttr:$crdUsedLvls,
1748
+ I64BitSetArrayAttr:$cases);
1749
+ let results = (outs Variadic<AnyType>:$results);
1750
+ let regions = (region VariadicRegion<SizedRegion<1>>:$caseRegions);
1751
+
1752
+ let extraClassDeclaration = [{
1753
+ unsigned getSpaceDim() {
1754
+ return llvm::cast<::mlir::sparse_tensor::IterSpaceType>(
1755
+ getIterSpaces().front().getType())
1756
+ .getSpaceDim();
1757
+ }
1758
+ I64BitSet getRegionDefinedSpace(unsigned regionIdx) {
1759
+ return I64BitSet(llvm::cast<IntegerAttr>(getCases()[regionIdx])
1760
+ .getValue().getZExtValue());
1761
+ }
1762
+ auto getRegionDefinedSpaces() {
1763
+ return llvm::map_range(getCases().getValue(), [](Attribute attr) {
1764
+ return I64BitSet(llvm::cast<IntegerAttr>(attr).getValue().getZExtValue());
1765
+ });
1766
+ }
1767
+
1768
+ // The block arguments starts with referenced coordinates, follows by
1769
+ // user-provided iteration arguments and ends with iterators.
1770
+ Block::BlockArgListType getCrds(unsigned regionIdx) {
1771
+ return getRegion(regionIdx).getArguments()
1772
+ .take_front(getCrdUsedLvls().count());
1773
+ }
1774
+ unsigned getNumRegionIterArgs(unsigned regionIdx) {
1775
+ return getInitArgs().size();
1776
+ }
1777
+ Block::BlockArgListType getRegionIterArgs(unsigned regionIdx) {
1778
+ return getRegion(regionIdx).getArguments()
1779
+ .slice(getCrdUsedLvls().count(), getNumRegionIterArgs(regionIdx));
1780
+ }
1781
+ Block::BlockArgListType getRegionIterators(unsigned regionIdx) {
1782
+ return getRegion(regionIdx).getArguments()
1783
+ .take_back(getRegionDefinedSpace(regionIdx).count());
1784
+ }
1785
+ ValueRange getYieldedValues(unsigned regionIdx);
1786
+ }];
1787
+
1788
+ let hasVerifier = 1;
1789
+ let hasRegionVerifier = 1;
1790
+ let hasCustomAssemblyFormat = 1;
1791
+ }
1792
+
1672
1793
//===----------------------------------------------------------------------===//
1673
1794
// Sparse Tensor Debugging and Test-Only Operations.
1674
1795
//===----------------------------------------------------------------------===//
0 commit comments