[mlir][sparse] introduce sparse_tensor.coiterate operation. #101100

Merged 4 commits on Jul 31, 2024
55 changes: 40 additions & 15 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h

@@ -61,37 +61,62 @@ struct COOSegment {
 /// A simple wrapper to encode a bitset of (at most 64) levels, currently used
 /// by `sparse_tensor.iterate` operation for the set of levels on which the
 /// coordinates should be loaded.
-class LevelSet {
-  uint64_t bits = 0;
+class I64BitSet {
+  uint64_t storage = 0;

 public:
-  LevelSet() = default;
-  explicit LevelSet(uint64_t bits) : bits(bits) {}
-  operator uint64_t() const { return bits; }
+  using const_set_bits_iterator = llvm::const_set_bits_iterator_impl<I64BitSet>;
+  const_set_bits_iterator begin() const {
+    return const_set_bits_iterator(*this);
+  }
+  const_set_bits_iterator end() const {
+    return const_set_bits_iterator(*this, -1);
+  }
+  iterator_range<const_set_bits_iterator> bits() const {
+    return make_range(begin(), end());
+  }
+
+  I64BitSet() = default;
+  explicit I64BitSet(uint64_t bits) : storage(bits) {}
+  operator uint64_t() const { return storage; }

-  LevelSet &set(unsigned i) {
+  I64BitSet &set(unsigned i) {
     assert(i < 64);
-    bits |= static_cast<uint64_t>(0x01u) << i;
+    storage |= static_cast<uint64_t>(0x01u) << i;
     return *this;
   }

-  LevelSet &operator|=(LevelSet lhs) {
-    bits |= static_cast<uint64_t>(lhs);
+  I64BitSet &operator|=(I64BitSet lhs) {
+    storage |= static_cast<uint64_t>(lhs);
     return *this;
   }

-  LevelSet &lshift(unsigned offset) {
-    bits = bits << offset;
+  I64BitSet &lshift(unsigned offset) {
+    storage = storage << offset;
     return *this;
   }

+  // Needed by `llvm::const_set_bits_iterator_impl`.
+  int find_first() const { return min(); }
+  int find_next(unsigned prev) const {
+    if (prev >= max())
+      return -1;
+
+    uint64_t b = storage >> (prev + 1);
+    if (b == 0)
+      return -1;
+
+    return llvm::countr_zero(b) + prev + 1;
+  }
+
   bool operator[](unsigned i) const {
     assert(i < 64);
-    return (bits & (1 << i)) != 0;
+    return (storage & (static_cast<uint64_t>(0x01u) << i)) != 0;
   }
-  unsigned max() const { return 64 - llvm::countl_zero(bits); }
-  unsigned count() const { return llvm::popcount(bits); }
-  bool empty() const { return bits == 0; }
+  unsigned min() const { return llvm::countr_zero(storage); }
+  unsigned max() const { return 64 - llvm::countl_zero(storage); }
+  unsigned count() const { return llvm::popcount(storage); }
+  bool empty() const { return storage == 0; }
 };

 } // namespace sparse_tensor
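The `find_first`/`find_next` pair is what plugs `I64BitSet` into LLVM's generic set-bit iterator. As a quick sanity check of that contract, here is a minimal standalone sketch; it is not part of the patch, uses C++20 `<bit>` in place of LLVM's `MathExtras.h`, and returns -1 from `find_first` on an empty set as the iterator protocol expects:

```cpp
// Illustrative only -- mirrors the find_first/find_next contract that
// llvm::const_set_bits_iterator_impl relies on.
#include <bit>      // std::countr_zero (C++20)
#include <cassert>
#include <cstdint>
#include <vector>

struct TinyBitSet {
  uint64_t storage = 0;
  // Index of the lowest set bit, or -1 if none.
  int find_first() const {
    return storage == 0 ? -1 : std::countr_zero(storage);
  }
  // Index of the lowest set bit above `prev`, or -1 if none.
  int find_next(unsigned prev) const {
    if (prev >= 63)
      return -1;
    uint64_t rest = storage >> (prev + 1);
    return rest == 0 ? -1 : std::countr_zero(rest) + static_cast<int>(prev) + 1;
  }
};

int main() {
  TinyBitSet set{0b101001}; // bits 0, 3, and 5 are set
  std::vector<int> seen;
  for (int i = set.find_first(); i != -1; i = set.find_next(i))
    seen.push_back(i);
  assert((seen == std::vector<int>{0, 3, 5}));
  return 0;
}
```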

@@ -24,16 +24,17 @@ class SparseTensor_Attr<string name,
 // sparse tensor levels.
 //===----------------------------------------------------------------------===//

-def LevelSetAttr :
-    TypedAttrBase<
-      I64, "IntegerAttr",
+def I64BitSetAttr : TypedAttrBase<I64, "IntegerAttr",
       And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
            CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getType().isInteger(64)">]>,
       "LevelSet attribute"> {
-  let returnType = [{::mlir::sparse_tensor::LevelSet}];
-  let convertFromStorage = [{::mlir::sparse_tensor::LevelSet($_self.getValue().getZExtValue())}];
+  let returnType = [{::mlir::sparse_tensor::I64BitSet}];
+  let convertFromStorage = [{::mlir::sparse_tensor::I64BitSet($_self.getValue().getZExtValue())}];
 }

+def I64BitSetArrayAttr :
+    TypedArrayAttrBase<I64BitSetAttr, "I64BitSet array attribute">;
+
 //===----------------------------------------------------------------------===//
 // These attributes are just like `IndexAttr` except that they clarify whether
 // the index refers to a dimension (an axis of the semantic tensor) or a level
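For context, a hedged sketch of how a transform might materialize this attribute: the attribute stores a plain 64-bit integer, and the generated accessor reconstructs the bitset via `convertFromStorage` above. The helper name `encodeBitSet` is hypothetical and not part of this change:

```cpp
// Hypothetical helper (not in this patch): build the 64-bit IntegerAttr
// storage form that I64BitSetAttr wraps.
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/Builders.h"

mlir::IntegerAttr encodeBitSet(mlir::OpBuilder &builder,
                               mlir::sparse_tensor::I64BitSet set) {
  return builder.getI64IntegerAttr(
      static_cast<int64_t>(static_cast<uint64_t>(set)));
}
```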

127 changes: 124 additions & 3 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

@@ -1306,7 +1306,7 @@ def SparseTensor_SelectOp

 def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
     ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
-                 "ForeachOp", "IterateOp"]>]> {
+                 "ForeachOp", "IterateOp", "CoIterateOp"]>]> {
   let summary = "Yield from sparse_tensor set-like operations";
   let description = [{
     Yields a value from within a `binary`, `unary`, `reduce`,
@@ -1629,14 +1629,14 @@ def IterateOp : SparseTensor_Op<"iterate",

   let arguments = (ins AnySparseIterSpace:$iterSpace,
                        Variadic<AnyType>:$initArgs,
-                       LevelSetAttr:$crdUsedLvls);
+                       I64BitSetAttr:$crdUsedLvls);
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);

   let skipDefaultBuilders = 1;
   let builders = [
     OpBuilder<(ins "Value":$iterSpace, "ValueRange":$initArgs)>,
-    OpBuilder<(ins "Value":$iterSpace, "ValueRange":$initArgs, "LevelSet" :$crdUsedLvls)>
+    OpBuilder<(ins "Value":$iterSpace, "ValueRange":$initArgs, "I64BitSet" :$crdUsedLvls)>
   ];

   let extraClassDeclaration = [{
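Given the new builder signature, constructing an `iterate` loop that loads coordinates at levels 0 and 1 might look like the sketch below; `builder`, `loc`, `space`, and `init` are assumed to be in scope and are not from the patch:

```cpp
// Hypothetical call site for the I64BitSet-based builder declared above.
mlir::sparse_tensor::I64BitSet usedLvls;
usedLvls.set(0).set(1); // load coordinates at levels 0 and 1
auto loop = builder.create<mlir::sparse_tensor::IterateOp>(
    loc, space, /*initArgs=*/mlir::ValueRange{init}, usedLvls);
```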
@@ -1669,6 +1669,127 @@ def IterateOp : SparseTensor_Op<"iterate",
   let hasCustomAssemblyFormat = 1;
 }

+def SparseTensor_CoIterateOp : SparseTensor_Op<"coiterate",
+    [AttrSizedOperandSegments,
+     SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">,
+     RecursiveMemoryEffects]> {
+  let summary = "Co-iterates over a set of sparse iteration spaces";
+  let description = [{
+    The `sparse_tensor.coiterate` operation represents a loop (nest) over
+    a set of iteration spaces. The operation can have multiple regions, each
+    of which defines a case to compute a result at the current iteration.
+    Each case condition is defined solely by the pattern of specified iterators.
+    For example:
+    ```mlir
+    %ret = sparse_tensor.coiterate (%sp1, %sp2) at(%coord) iter_args(%arg = %init)
+         : (!sparse_tensor.iter_space<#CSR, lvls = 0>,
+            !sparse_tensor.iter_space<#COO, lvls = 0>)
+         -> index
+    case %it1, _ {
+      // %coord is specified in space %sp1 but *NOT* in space %sp2.
+    }
+    case %it1, %it2 {
+      // %coord is specified in *BOTH* spaces %sp1 and %sp2.
+    }
+    ```

+    `sparse_tensor.coiterate` can also operate on loop-carried variables and
+    returns the final value for each of them after loop termination. The
+    initial values of the variables are passed as additional SSA operands
+    following the iteration space operands. Each case region takes a variadic
+    list of block arguments: the specified (used) coordinates come first,
+    followed by one argument per loop-carried variable (the value of that
+    variable at the current iteration), followed by the iterators defined
+    by the case. The body region must contain exactly one block that
+    terminates with `sparse_tensor.yield`.
+
+    The results of a `sparse_tensor.coiterate` operation hold the final
+    values after the last iteration. If the operation defines any results,
+    a yield must be explicitly present in every region defined in the
+    operation. The number and types of the results must match the initial
+    values in the `iter_args` binding and the yield operands.
+
+    The following example performs element-wise addition of two sparse vectors:
+
+    ```mlir
+    %ret = sparse_tensor.coiterate (%sp1, %sp2) at(%coord) iter_args(%arg = %init)
+         : (!sparse_tensor.iter_space<#CSR, lvls = 0>,
+            !sparse_tensor.iter_space<#CSR, lvls = 0>)
+         -> tensor<?xindex, #CSR>
+    case %it1, _ {
+      // v = v1 + 0 = v1
+      %v1 = sparse_tensor.extract_value %t1 at %it1 : index
+      %yield = sparse_tensor.insert %v1 into %arg[%coord]
+      sparse_tensor.yield %yield
+    }
+    case _, %it2 {
+      // v = 0 + v2 = v2
+      %v2 = sparse_tensor.extract_value %t2 at %it2 : index
+      %yield = sparse_tensor.insert %v2 into %arg[%coord]
+      sparse_tensor.yield %yield
+    }
+    case %it1, %it2 {
+      // v = v1 + v2
+      %v1 = sparse_tensor.extract_value %t1 at %it1 : index
+      %v2 = sparse_tensor.extract_value %t2 at %it2 : index
+      %v = arith.addi %v1, %v2 : index
+      %yield = sparse_tensor.insert %v into %arg[%coord]
+      sparse_tensor.yield %yield
+    }
+    ```
+  }];

+  let arguments = (ins Variadic<AnySparseIterSpace>:$iterSpaces,
+                       Variadic<AnyType>:$initArgs,
+                       I64BitSetAttr:$crdUsedLvls,
+                       I64BitSetArrayAttr:$cases);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region VariadicRegion<SizedRegion<1>>:$caseRegions);
+
+  let extraClassDeclaration = [{
+    unsigned getSpaceDim() {
+      return llvm::cast<::mlir::sparse_tensor::IterSpaceType>(
+                 getIterSpaces().front().getType())
+          .getSpaceDim();
+    }
+    I64BitSet getRegionDefinedSpace(unsigned regionIdx) {
+      return I64BitSet(llvm::cast<IntegerAttr>(getCases()[regionIdx])
+                           .getValue().getZExtValue());
+    }
+    auto getRegionDefinedSpaces() {
+      return llvm::map_range(getCases().getValue(), [](Attribute attr) {
+        return I64BitSet(llvm::cast<IntegerAttr>(attr).getValue().getZExtValue());
+      });
+    }
+
+    // The block arguments start with the referenced coordinates, followed
+    // by the user-provided iteration arguments, and end with the iterators.
+    Block::BlockArgListType getCrds(unsigned regionIdx) {
+      return getRegion(regionIdx).getArguments()
+          .take_front(getCrdUsedLvls().count());
+    }
+    unsigned getNumRegionIterArgs(unsigned regionIdx) {
+      return getInitArgs().size();
+    }
+    Block::BlockArgListType getRegionIterArgs(unsigned regionIdx) {
+      return getRegion(regionIdx).getArguments()
+          .slice(getCrdUsedLvls().count(), getNumRegionIterArgs(regionIdx));
+    }
+    Block::BlockArgListType getRegionIterators(unsigned regionIdx) {
+      return getRegion(regionIdx).getArguments()
+          .take_back(getRegionDefinedSpace(regionIdx).count());
+    }
+    ValueRange getYieldedValues(unsigned regionIdx);
+  }];
+
+  let hasVerifier = 1;
+  let hasRegionVerifier = 1;
+  let hasCustomAssemblyFormat = 1;
+}

 //===----------------------------------------------------------------------===//
 // Sparse Tensor Debugging and Test-Only Operations.
 //===----------------------------------------------------------------------===//
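To make the `cases` encoding concrete: each region's `I64BitSet` records which iterators that case requires to be live. Below is a hedged sketch, not from this patch, of how a lowering could match a runtime pattern of live spaces against the regions using the accessors declared above; the first-match policy is an assumption, not something the op definition mandates:

```cpp
// Illustrative only: pick the first case region whose required-iterator
// bits are all contained in the currently live spaces.
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "llvm/ADT/STLExtras.h"

mlir::Region *selectCaseRegion(mlir::sparse_tensor::CoIterateOp op,
                               mlir::sparse_tensor::I64BitSet liveSpaces) {
  for (auto [idx, caseBits] : llvm::enumerate(op.getRegionDefinedSpaces())) {
    uint64_t needed = static_cast<uint64_t>(caseBits);
    if ((needed & static_cast<uint64_t>(liveSpaces)) == needed)
      return &op->getRegion(idx);
  }
  return nullptr; // no case covers this pattern of live iterators
}
```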