Skip to content

Commit 098f46d

Browse files
authored
[sparse] allow unpack op to return 0-ranked tensor type. (#66269)
Many frontends canonicalize scalar into 0-ranked tensor, it change will hopefully make the operation easier to use for those cases.
1 parent 372115f commit 098f46d

File tree

5 files changed

+26
-6
lines changed

5 files changed

+26
-6
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,10 @@ class RankedSparseTensorOf<list<Type> allowedTypes>
438438

439439
def AnyRankedSparseTensor : RankedSparseTensorOf<[AnyType]>;
440440

441+
class ScalarLikeOf<list<Type> allowedTypes>
442+
: AnyTypeOf<[0DTensorOf<allowedTypes>, AnyTypeOf<allowedTypes>]>;
443+
444+
441445
//===----------------------------------------------------------------------===//
442446
// Sparse Tensor Sorting Algorithm Attribute.
443447
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,8 @@ def SparseTensor_UnpackOp : SparseTensor_Op<"unpack", [Pure, SameVariadicResultS
108108
Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels)>,
109109
Results<(outs TensorOf<[AnyType]>:$ret_values,
110110
Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
111-
AnySignlessIntegerOrIndex:$val_len,
112-
Variadic<AnySignlessIntegerOrIndex>:$lvl_lens)> {
111+
ScalarLikeOf<[AnySignlessIntegerOrIndex]>:$val_len,
112+
Variadic<ScalarLikeOf<[AnySignlessIntegerOrIndex]>>:$lvl_lens)> {
113113
let summary = "Returns the (values, coordinates) pair unpacked from the input tensor";
114114

115115
let description = [{

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,18 @@ static ReassociationIndices getReassociationForFlattening(ShapedType srcTp) {
559559
return reassociation;
560560
}
561561

562+
static Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem,
563+
Type dstTp) {
564+
if (auto rtp = dstTp.dyn_cast<RankedTensorType>()) {
565+
// Scalars can only be converted to 0-ranked tensors.
566+
if (rtp.getRank() != 0)
567+
return nullptr;
568+
elem = genCast(builder, loc, elem, rtp.getElementType());
569+
return builder.create<tensor::FromElementsOp>(loc, rtp, elem);
570+
}
571+
return genCast(builder, loc, elem, dstTp);
572+
}
573+
562574
//===----------------------------------------------------------------------===//
563575
// Codegen rules.
564576
//===----------------------------------------------------------------------===//
@@ -1324,7 +1336,8 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
13241336
// consistent.
13251337
retMem.insert(retMem.begin(), dst);
13261338
Type valLenTp = op.getValLen().getType();
1327-
retLen.insert(retLen.begin(), genCast(rewriter, loc, sz, valLenTp));
1339+
retLen.insert(retLen.begin(),
1340+
genScalarToTensor(rewriter, loc, sz, valLenTp));
13281341
} else {
13291342
assert(fKind == SparseTensorFieldKind::PosMemRef ||
13301343
fKind == SparseTensorFieldKind::CrdMemRef);
@@ -1337,7 +1350,7 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
13371350
retMem.push_back(dst);
13381351
// Retrieves the corresponding level length type.
13391352
Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
1340-
retLen.push_back(genCast(rewriter, loc, sz, lvlLenTp));
1353+
retLen.push_back(genScalarToTensor(rewriter, loc, sz, lvlLenTp));
13411354
}
13421355
Value flatOut = dst;
13431356
if (dst.getType().getRank() != 1) {

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,8 @@ struct SparseTensorCodegenPass
214214
target.addLegalOp<GetStorageSpecifierOp>();
215215
target.addLegalOp<SetStorageSpecifierOp>();
216216
target.addLegalOp<StorageSpecifierInitOp>();
217+
// Note that tensor::FromElementsOp might be yield after lowering unpack.
218+
target.addLegalOp<tensor::FromElementsOp>();
217219
// All dynamic rules below accept new function, call, return, and
218220
// various tensor and bufferization operations as legal output of the
219221
// rewriting provided that all sparse tensor types have been fully

mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ module {
219219
%boi = tensor.empty() : tensor<6x2xindex>
220220
%bd, %bp, %bi, %ld, %lp, %li = sparse_tensor.unpack %bs : tensor<2x10x10xf64, #BCOO>
221221
outs(%bod, %bop, %boi : tensor<6xf64>, tensor<4xindex>, tensor<6x2xindex>)
222-
-> tensor<6xf64>, (tensor<4xindex>, tensor<6x2xindex>), index, (i32, i64)
222+
-> tensor<6xf64>, (tensor<4xindex>, tensor<6x2xindex>), index, (i32, tensor<i64>)
223223

224224
// CHECK-NEXT: ( 1, 2, 3, 4, 5, {{.*}} )
225225
%vbd = vector.transfer_read %bd[%c0], %f0 : tensor<6xf64>, vector<6xf64>
@@ -231,7 +231,8 @@ module {
231231
%vbi = vector.transfer_read %bi[%c0, %c0], %c0 : tensor<6x2xindex>, vector<6x2xindex>
232232
vector.print %vbi : vector<6x2xindex>
233233
// CHECK-NEXT: 10
234-
vector.print %li : i64
234+
%si = tensor.extract %li[] : tensor<i64>
235+
vector.print %si : i64
235236

236237
return
237238
}

0 commit comments

Comments
 (0)