
Commit d4db528

Author: Peiming Liu
[mlir][sparse] extend unpack operation to support unpacking a batched COO type
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D149103
1 parent f9fbda7 commit d4db528

File tree

13 files changed (+393, -125 lines)


mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

Lines changed: 39 additions & 5 deletions

@@ -124,9 +124,10 @@ def SparseTensor_PackOp : SparseTensor_Op<"pack", [Pure]>,
 }

 def SparseTensor_UnpackOp : SparseTensor_Op<"unpack">,
-    Arguments<(ins AnySparseTensor:$tensor)>,
-    Results<(outs 1DTensorOf<[AnyType]>:$values,
-                  2DTensorOf<[AnySignlessIntegerOrIndex]>:$coordinates,
+    Arguments<(ins AnySparseTensor:$tensor,
+                   OptionalAttr<IndexAttr>:$batched_lvls)>,
+    Results<(outs TensorOf<[AnyType]>:$values,
+                  TensorOf<[AnySignlessIntegerOrIndex]>:$coordinates,
                   AnySignlessIntegerOrIndex:$nse)> {
   let summary = "Returns the (values, coordinates) pair unpacked from the input tensor";

@@ -159,11 +160,44 @@ def SparseTensor_UnpackOp : SparseTensor_Op<"unpack">,
    // %coordinates = arith.constant dense<[[0,0], [1,2], [1,3]]> : tensor<3x2xindex>
    // %nse = 3
    ```
+
+    If `batched_lvls` is provided, the operation unpacks each batch of the tensors
+    separately. The returned `nse` is the maximum nse of all batches. For a batch
+    with a smaller nse, trailing zeros are appended in the result.
+    Example:
+
+    ```mlir
+    // input BCOO format |1.1, 2.2, 3.3, 0.0|
+    //   of 2x4 matrix   |0.0, 1.2, 2.3, 0.0|
+    %values, %coordinates, %nse = sparse_tensor.unpack %st batched_lvls=1
+        : tensor<2x4xf64, #BCOO> to tensor<2x3xf64>, tensor<2x3x1xindex>, index
+    // %values = arith.constant dense<[[ 1.1, 2.2, 3.3 ],
+    //                                 [ 1.2, 2.3, 0.0 ]]> : tensor<2x3xf64>
+    // %coordinates = arith.constant dense<[[ [0], [1], [2] ],
+    //                                      [ [1], [2], [0] ]]> : tensor<2x3x1xindex>
+    ```
+  }];
+
+  let extraClassDeclaration = [{
+    /// Returns the number of leading levels that are batched.
+    unsigned getNumBatchedLvls();
   }];

+  let builders = [
+    OpBuilder<(ins "Type":$values, "Type":$coordinates, "Type":$nse, "Value":$tensor),
+    [{
+      build($_builder, $_state, values, coordinates, nse, tensor, nullptr);
+    }]>,
+    OpBuilder<(ins "TypeRange":$resultTypes, "Value":$tensor),
+    [{
+      build($_builder, $_state, resultTypes, tensor, nullptr);
+    }]>
+  ];
+
   let assemblyFormat =
-    "$tensor attr-dict `:` type($tensor)"
-    "`to` type($values) `,` type($coordinates) `,` type($nse)";
+    "$tensor (`batched_lvls` `=` $batched_lvls^)? attr-dict `:`"
+    "type($tensor) `to` type($values) `,` type($coordinates) `,` type($nse)";

   let hasVerifier = 1;
 }
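Note that the result types are relaxed from `1DTensorOf`/`2DTensorOf` to rank-agnostic `TensorOf` so that batched unpacking can return higher-rank values/coordinates tensors. For readers skimming the diff, the old unbatched spelling stays valid under the new optional syntax; a minimal sketch (the `#COO` encoding and concrete shapes here are illustrative assumptions, not part of the diff):

```mlir
// batched_lvls omitted: getNumBatchedLvls() returns 0 and the op behaves
// exactly as before, yielding 1-D values and 2-D coordinates.
%values, %coordinates, %nse = sparse_tensor.unpack %st
    : tensor<100x100xf64, #COO> to tensor<3xf64>, tensor<3x2xindex>, index
```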

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 5 additions & 1 deletion

@@ -719,7 +719,11 @@ LogicalResult UnpackOp::verify() {
   const auto coordinatesTp = getRankedTensorType(getCoordinates());
   const auto srcTp = getSparseTensorType(getTensor());
   return verifyPackUnPack(*this, false, srcTp, valuesTp, coordinatesTp,
-                          nullptr);
+                          getBatchedLvlsAttr());
+}
+
+unsigned UnpackOp::getNumBatchedLvls() {
+  return getBatchedLvls().has_value() ? getBatchedLvls()->getZExtValue() : 0;
 }

 LogicalResult ConvertOp::verify() {

mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 6 additions & 3 deletions

@@ -153,9 +153,12 @@ struct UnpackOpInterface
     : public BufferizableOpInterface::ExternalModel<UnpackOpInterface,
                                                     sparse_tensor::UnpackOp> {
   bool bufferizesToAllocation(Operation *op, OpResult opResult) const {
-    // Similar to InsertOp, reallocation is not considered to allocate a new
-    // piece of memory.
-    return false;
+    // We allocate and return unpacked memory if this is a batched unpack.
+    // When the number of batched levels equals zero, we reuse the
+    // coordinates/values memrefs (and reallocate if the requested output size
+    // is larger than the actual size). Similar to InsertOp, reallocation is
+    // not considered to allocate a new piece of memory.
+    return llvm::cast<UnpackOp>(op).getNumBatchedLvls() != 0;
   }

   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
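Put differently, only the batched form now reports an allocation to the bufferization framework. A hedged sketch of the case this change affects, mirroring the op documentation above (encoding and shapes assumed for illustration):

```mlir
// getNumBatchedLvls() == 1, so bufferizesToAllocation returns true: each
// batch is unpacked separately into newly allocated values/coordinates
// buffers instead of reusing (or at most reallocating) the input's memrefs.
%values, %coordinates, %nse = sparse_tensor.unpack %st batched_lvls=1
    : tensor<2x4xf64, #BCOO> to tensor<2x3xf64>, tensor<2x3x1xindex>, index
```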

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp

Lines changed: 12 additions & 0 deletions

@@ -213,6 +213,18 @@ Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
   return mlir::convertScalarToDtype(builder, loc, value, dstTp, isUnsignedCast);
 }

+Value sparse_tensor::genIndexLoad(OpBuilder &builder, Location loc, Value mem,
+                                  Value s) {
+  Value load = builder.create<memref::LoadOp>(loc, mem, s);
+  if (!load.getType().isa<IndexType>()) {
+    if (load.getType().getIntOrFloatBitWidth() < 64)
+      load = builder.create<arith::ExtUIOp>(loc, builder.getI64Type(), load);
+    load =
+        builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), load);
+  }
+  return load;
+}
+
 mlir::TypedAttr mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
   if (tp.isa<FloatType>())
     return builder.getFloatAttr(tp, 1.0);
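For an i32 positions/coordinates buffer, the moved helper emits roughly the following sequence (value names are placeholders; an already index-typed memref yields just the load):

```mlir
%ld  = memref.load %mem[%s] : memref<?xi32>    // raw narrow load
%ext = arith.extui %ld : i32 to i64            // zero extend widths below 64 bits
%idx = arith.index_cast %ext : i64 to index    // cast to index for looping/indexing
```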

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h

Lines changed: 5 additions & 0 deletions

@@ -75,6 +75,11 @@ StringRef primaryTypeFunctionSuffix(Type elemTp);
 /// Add type casting between arith and index types when needed.
 Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy);

+/// Generates a pointer/index load from the sparse storage scheme. Narrower
+/// data types need to be zero extended before casting the value into the
+/// index type used for looping and indexing.
+Value genIndexLoad(OpBuilder &builder, Location loc, Value mem, Value s);
+
 /// Generates a 1-valued attribute of the given type. This supports
 /// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`,
 /// for unsupported types we raise `llvm_unreachable` rather than

mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp

Lines changed: 2 additions & 20 deletions

@@ -41,25 +41,6 @@ using namespace mlir::sparse_tensor;
 // File local helper functions.
 //===----------------------------------------------------------------------===//

-/// Generates a pointer/index load from the sparse storage scheme. Narrower
-/// data types need to be zero extended before casting the value into the
-/// index type used for looping and indexing.
-static Value genIndexLoad(OpBuilder &builder, Location loc, Value mem,
-                          Value s) {
-  // For the scalar case, we simply zero extend narrower indices into 64-bit
-  // values before casting to index without a performance penalty. Here too,
-  // however, indices that already are 64-bit, in theory, cannot express the
-  // full range as explained above.
-  Value load = builder.create<memref::LoadOp>(loc, mem, s);
-  if (!load.getType().isa<IndexType>()) {
-    if (load.getType().getIntOrFloatBitWidth() < 64)
-      load = builder.create<arith::ExtUIOp>(loc, builder.getI64Type(), load);
-    load =
-        builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), load);
-  }
-  return load;
-}
-
 static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor,
                             Level lvl) {
   auto enc = getSparseTensorEncoding(tensor.getType());

@@ -707,7 +688,8 @@ Operation *LoopEmitter::enterLoopOverTensorAtLvl(
       continue;
     }

-    bool isSparse = isCompressedDLT(lvlType) || isSingletonDLT(lvlType);
+    bool isSparse = isCompressedDLT(lvlType) || isSingletonDLT(lvlType) ||
+                    isCompressedWithHiDLT(lvlType);
     // We can at most have one sparse input, otherwise, a while loop is
     // required to co-iterate multiple sparse tensors.
     assert(!isSparseCond || !isSparse);
