Skip to content

Commit 785a24f

Browse files
author
Peiming Liu
authored
[mlir][sparse] introduce sparse_tensor.coiterate operation. (llvm#101100)
This PR introduces `sparse_tensor.coiterate` operation, which represents a loop that traverses multiple sparse iteration space.
1 parent 38ef692 commit 785a24f

File tree

7 files changed

+569
-99
lines changed

7 files changed

+569
-99
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -61,37 +61,62 @@ struct COOSegment {
6161
/// A simple wrapper to encode a bitset of (at most 64) levels, currently used
6262
/// by `sparse_tensor.iterate` operation for the set of levels on which the
6363
/// coordinates should be loaded.
64-
class LevelSet {
65-
uint64_t bits = 0;
64+
class I64BitSet {
65+
uint64_t storage = 0;
6666

6767
public:
68-
LevelSet() = default;
69-
explicit LevelSet(uint64_t bits) : bits(bits) {}
70-
operator uint64_t() const { return bits; }
68+
using const_set_bits_iterator = llvm::const_set_bits_iterator_impl<I64BitSet>;
69+
const_set_bits_iterator begin() const {
70+
return const_set_bits_iterator(*this);
71+
}
72+
const_set_bits_iterator end() const {
73+
return const_set_bits_iterator(*this, -1);
74+
}
75+
iterator_range<const_set_bits_iterator> bits() const {
76+
return make_range(begin(), end());
77+
}
78+
79+
I64BitSet() = default;
80+
explicit I64BitSet(uint64_t bits) : storage(bits) {}
81+
operator uint64_t() const { return storage; }
7182

72-
LevelSet &set(unsigned i) {
83+
I64BitSet &set(unsigned i) {
7384
assert(i < 64);
74-
bits |= static_cast<uint64_t>(0x01u) << i;
85+
storage |= static_cast<uint64_t>(0x01u) << i;
7586
return *this;
7687
}
7788

78-
LevelSet &operator|=(LevelSet lhs) {
79-
bits |= static_cast<uint64_t>(lhs);
89+
I64BitSet &operator|=(I64BitSet lhs) {
90+
storage |= static_cast<uint64_t>(lhs);
8091
return *this;
8192
}
8293

83-
LevelSet &lshift(unsigned offset) {
84-
bits = bits << offset;
94+
I64BitSet &lshift(unsigned offset) {
95+
storage = storage << offset;
8596
return *this;
8697
}
8798

99+
// Needed by `llvm::const_set_bits_iterator_impl`.
100+
int find_first() const { return min(); }
101+
int find_next(unsigned prev) const {
102+
if (prev >= max())
103+
return -1;
104+
105+
uint64_t b = storage >> (prev + 1);
106+
if (b == 0)
107+
return -1;
108+
109+
return llvm::countr_zero(b) + prev + 1;
110+
}
111+
88112
bool operator[](unsigned i) const {
89113
assert(i < 64);
90-
return (bits & (1 << i)) != 0;
114+
return (storage & (1 << i)) != 0;
91115
}
92-
unsigned max() const { return 64 - llvm::countl_zero(bits); }
93-
unsigned count() const { return llvm::popcount(bits); }
94-
bool empty() const { return bits == 0; }
116+
unsigned min() const { return llvm::countr_zero(storage); }
117+
unsigned max() const { return 64 - llvm::countl_zero(storage); }
118+
unsigned count() const { return llvm::popcount(storage); }
119+
bool empty() const { return storage == 0; }
95120
};
96121

97122
} // namespace sparse_tensor

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,17 @@ class SparseTensor_Attr<string name,
2424
// sparse tensor levels.
2525
//===----------------------------------------------------------------------===//
2626

27-
def LevelSetAttr :
28-
TypedAttrBase<
29-
I64, "IntegerAttr",
27+
def I64BitSetAttr : TypedAttrBase<I64, "IntegerAttr",
3028
And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
3129
CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getType().isInteger(64)">]>,
3230
"LevelSet attribute"> {
33-
let returnType = [{::mlir::sparse_tensor::LevelSet}];
34-
let convertFromStorage = [{::mlir::sparse_tensor::LevelSet($_self.getValue().getZExtValue())}];
31+
let returnType = [{::mlir::sparse_tensor::I64BitSet}];
32+
let convertFromStorage = [{::mlir::sparse_tensor::I64BitSet($_self.getValue().getZExtValue())}];
3533
}
3634

35+
def I64BitSetArrayAttr :
36+
TypedArrayAttrBase<I64BitSetAttr, "I64BitSet array attribute">;
37+
3738
//===----------------------------------------------------------------------===//
3839
// These attributes are just like `IndexAttr` except that they clarify whether
3940
// the index refers to a dimension (an axis of the semantic tensor) or a level

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

Lines changed: 124 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1306,7 +1306,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
13061306

13071307
def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
13081308
ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
1309-
"ForeachOp", "IterateOp"]>]> {
1309+
"ForeachOp", "IterateOp", "CoIterateOp"]>]> {
13101310
let summary = "Yield from sparse_tensor set-like operations";
13111311
let description = [{
13121312
Yields a value from within a `binary`, `unary`, `reduce`,
@@ -1629,14 +1629,14 @@ def IterateOp : SparseTensor_Op<"iterate",
16291629

16301630
let arguments = (ins AnySparseIterSpace:$iterSpace,
16311631
Variadic<AnyType>:$initArgs,
1632-
LevelSetAttr:$crdUsedLvls);
1632+
I64BitSetAttr:$crdUsedLvls);
16331633
let results = (outs Variadic<AnyType>:$results);
16341634
let regions = (region SizedRegion<1>:$region);
16351635

16361636
let skipDefaultBuilders = 1;
16371637
let builders = [
16381638
OpBuilder<(ins "Value":$iterSpace, "ValueRange":$initArgs)>,
1639-
OpBuilder<(ins "Value":$iterSpace, "ValueRange":$initArgs, "LevelSet" :$crdUsedLvls)>
1639+
OpBuilder<(ins "Value":$iterSpace, "ValueRange":$initArgs, "I64BitSet" :$crdUsedLvls)>
16401640
];
16411641

16421642
let extraClassDeclaration = [{
@@ -1669,6 +1669,127 @@ def IterateOp : SparseTensor_Op<"iterate",
16691669
let hasCustomAssemblyFormat = 1;
16701670
}
16711671

1672+
def SparseTensor_CoIterateOp : SparseTensor_Op<"coiterate",
1673+
[AttrSizedOperandSegments,
1674+
SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">,
1675+
RecursiveMemoryEffects]> {
1676+
let summary = "Co-iterates over a set of sparse iteration spaces";
1677+
let description = [{
1678+
The `sparse_tensor.coiterate` operation represents a loop (nest) over
1679+
a set of iteration spaces. The operation can have multiple regions,
1680+
with each of them defining a case to compute a result at the current iterations.
1681+
The case condition is defined solely based on the pattern of specified iterators.
1682+
For example:
1683+
```mlir
1684+
%ret = sparse_tensor.coiterate (%sp1, %sp2) at(%coord) iter_args(%arg = %init)
1685+
: (!sparse_tensor.iter_space<#CSR, lvls = 0>,
1686+
!sparse_tensor.iter_space<#COO, lvls = 0>)
1687+
-> index
1688+
case %it1, _ {
1689+
// %coord is specifed in space %sp1 but *NOT* specified in space %sp2.
1690+
}
1691+
case %it1, %it2 {
1692+
// %coord is specifed in *BOTH* spaces %sp1 and %sp2.
1693+
}
1694+
```
1695+
1696+
`sparse_tensor.coiterate` can also operate on loop-carried variables.
1697+
It returns the final value for each loop-carried variable after loop termination.
1698+
The initial values of the variables are passed as additional SSA operands
1699+
to the iterator SSA value and used coordinate SSA values.
1700+
Each operation region has variadic arguments for specified (used), one argument
1701+
for each loop-carried variable, representing the value of the variable
1702+
at the current iteration, followed by a list of arguments for iterators.
1703+
The body region must contain exactly one block that terminates with
1704+
`sparse_tensor.yield`.
1705+
1706+
The results of an `sparse_tensor.coiterate` hold the final values after
1707+
the last iteration. If the `sparse_tensor.coiterate` defines any values,
1708+
a yield must be explicitly present in every region defined in the operation.
1709+
The number and types of the `sparse_tensor.coiterate` results must match
1710+
the initial values in the iter_args binding and the yield operands.
1711+
1712+
1713+
A `sparse_tensor.coiterate` example that does elementwise addition between two
1714+
sparse vectors.
1715+
1716+
1717+
```mlir
1718+
%ret = sparse_tensor.coiterate (%sp1, %sp2) at(%coord) iter_args(%arg = %init)
1719+
: (!sparse_tensor.iter_space<#CSR, lvls = 0>,
1720+
!sparse_tensor.iter_space<#CSR, lvls = 0>)
1721+
-> tensor<?xindex, #CSR>
1722+
case %it1, _ {
1723+
// v = v1 + 0 = v1
1724+
%v1 = sparse_tensor.extract_value %t1 at %it1 : index
1725+
%yield = sparse_tensor.insert %v1 into %arg[%coord]
1726+
sparse_tensor.yield %yield
1727+
}
1728+
case _, %it2 {
1729+
// v = v2 + 0 = v2
1730+
%v2 = sparse_tensor.extract_value %t2 at %it2 : index
1731+
%yield = sparse_tensor.insert %v1 into %arg[%coord]
1732+
sparse_tensor.yield %yield
1733+
}
1734+
case %it1, %it2 {
1735+
// v = v1 + v2
1736+
%v1 = sparse_tensor.extract_value %t1 at %it1 : index
1737+
%v2 = sparse_tensor.extract_value %t2 at %it2 : index
1738+
%v = arith.addi %v1, %v2 : index
1739+
%yield = sparse_tensor.insert %v into %arg[%coord]
1740+
sparse_tensor.yield %yield
1741+
}
1742+
```
1743+
}];
1744+
1745+
let arguments = (ins Variadic<AnySparseIterSpace>:$iterSpaces,
1746+
Variadic<AnyType>:$initArgs,
1747+
I64BitSetAttr:$crdUsedLvls,
1748+
I64BitSetArrayAttr:$cases);
1749+
let results = (outs Variadic<AnyType>:$results);
1750+
let regions = (region VariadicRegion<SizedRegion<1>>:$caseRegions);
1751+
1752+
let extraClassDeclaration = [{
1753+
unsigned getSpaceDim() {
1754+
return llvm::cast<::mlir::sparse_tensor::IterSpaceType>(
1755+
getIterSpaces().front().getType())
1756+
.getSpaceDim();
1757+
}
1758+
I64BitSet getRegionDefinedSpace(unsigned regionIdx) {
1759+
return I64BitSet(llvm::cast<IntegerAttr>(getCases()[regionIdx])
1760+
.getValue().getZExtValue());
1761+
}
1762+
auto getRegionDefinedSpaces() {
1763+
return llvm::map_range(getCases().getValue(), [](Attribute attr) {
1764+
return I64BitSet(llvm::cast<IntegerAttr>(attr).getValue().getZExtValue());
1765+
});
1766+
}
1767+
1768+
// The block arguments starts with referenced coordinates, follows by
1769+
// user-provided iteration arguments and ends with iterators.
1770+
Block::BlockArgListType getCrds(unsigned regionIdx) {
1771+
return getRegion(regionIdx).getArguments()
1772+
.take_front(getCrdUsedLvls().count());
1773+
}
1774+
unsigned getNumRegionIterArgs(unsigned regionIdx) {
1775+
return getInitArgs().size();
1776+
}
1777+
Block::BlockArgListType getRegionIterArgs(unsigned regionIdx) {
1778+
return getRegion(regionIdx).getArguments()
1779+
.slice(getCrdUsedLvls().count(), getNumRegionIterArgs(regionIdx));
1780+
}
1781+
Block::BlockArgListType getRegionIterators(unsigned regionIdx) {
1782+
return getRegion(regionIdx).getArguments()
1783+
.take_back(getRegionDefinedSpace(regionIdx).count());
1784+
}
1785+
ValueRange getYieldedValues(unsigned regionIdx);
1786+
}];
1787+
1788+
let hasVerifier = 1;
1789+
let hasRegionVerifier = 1;
1790+
let hasCustomAssemblyFormat = 1;
1791+
}
1792+
16721793
//===----------------------------------------------------------------------===//
16731794
// Sparse Tensor Debugging and Test-Only Operations.
16741795
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)