[sparse] allow unpack op to return 0-ranked tensor type. #66269
Conversation
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir-sparse

Changes

Many frontends canonicalize scalars into 0-ranked tensors, so this change will hopefully make the operation easier to use for those cases.

Full diff: https://github.com//pull/66269.diff

5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index e2f3df005b70d69..bf077db43ec10e9 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -438,6 +438,10 @@ class RankedSparseTensorOf<list<Type> allowedTypes>
 
 def AnyRankedSparseTensor : RankedSparseTensorOf<[AnyType]>;
 
+class ScalarLikeOf<list<Type> allowedTypes>
+    : AnyTypeOf<[0DTensorOf<allowedTypes>, AnyTypeOf<allowedTypes>]>;
+
+
 //===----------------------------------------------------------------------===//
 // Sparse Tensor Sorting Algorithm Attribute.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 7d9f1d3b26c0678..7430a3c6118cef4 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -108,8 +108,8 @@ def SparseTensor_UnpackOp : SparseTensor_Op<"unpack", [Pure, SameVariadicResultS
                            Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels)>,
     Results<(outs TensorOf<[AnyType]>:$ret_values,
                   Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
-                  AnySignlessIntegerOrIndex:$val_len,
-                  Variadic<AnySignlessIntegerOrIndex>:$lvl_lens)> {
+                  ScalarLikeOf<[AnySignlessIntegerOrIndex]>:$val_len,
+                  Variadic<ScalarLikeOf<[AnySignlessIntegerOrIndex]>>:$lvl_lens)> {
   let summary = "Returns the (values, coordinates) pair unpacked from the input tensor";
 
   let description = [{
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 0c8a304841c10d5..557c5c471c4a77c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -559,6 +559,18 @@ static ReassociationIndices getReassociationForFlattening(ShapedType srcTp) {
   return reassociation;
 }
 
+static Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem,
+                               Type dstTp) {
+  if (auto rtp = dstTp.dyn_cast<RankedTensorType>()) {
+    // Scalars can only be converted to 0-ranked tensors.
+    if (rtp.getRank() != 0)
+      return nullptr;
+    elem = genCast(builder, loc, elem, rtp.getElementType());
+    return builder.create<tensor::FromElementsOp>(loc, rtp, elem);
+  }
+  return genCast(builder, loc, elem, dstTp);
+}
+
 //===----------------------------------------------------------------------===//
 // Codegen rules.
 //===----------------------------------------------------------------------===//
@@ -1324,7 +1336,8 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
         // consistent.
         retMem.insert(retMem.begin(), dst);
         Type valLenTp = op.getValLen().getType();
-        retLen.insert(retLen.begin(), genCast(rewriter, loc, sz, valLenTp));
+        retLen.insert(retLen.begin(),
+                      genScalarToTensor(rewriter, loc, sz, valLenTp));
       } else {
         assert(fKind == SparseTensorFieldKind::PosMemRef ||
                fKind == SparseTensorFieldKind::CrdMemRef);
@@ -1337,7 +1350,7 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
         retMem.push_back(dst);
         // Retrieves the corresponding level length type.
         Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
-        retLen.push_back(genCast(rewriter, loc, sz, lvlLenTp));
+        retLen.push_back(genScalarToTensor(rewriter, loc, sz, lvlLenTp));
       }
       Value flatOut = dst;
       if (dst.getType().getRank() != 1) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index cce26bc603eeb3c..2956cf57ade0290 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -214,6 +214,8 @@ struct SparseTensorCodegenPass
     target.addLegalOp<GetStorageSpecifierOp>();
     target.addLegalOp<SetStorageSpecifierOp>();
     target.addLegalOp<StorageSpecifierInitOp>();
+    // tensor::FromElementsOp might be yield after lowering unpack.
+    target.addLegalOp<tensor::FromElementsOp>();
     // All dynamic rules below accept new function, call, return, and
     // various tensor and bufferization operations as legal output of the
     // rewriting provided that all sparse tensor types have been fully
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
index cc8d538e6adfb83..d95efb507765403 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
@@ -219,7 +219,7 @@ module {
     %boi = tensor.empty() : tensor<6x2xindex>
     %bd, %bp, %bi, %ld, %lp, %li = sparse_tensor.unpack %bs : tensor<2x10x10xf64, #BCOO>
                  outs(%bod, %bop, %boi : tensor<6xf64>, tensor<4xindex>, tensor<6x2xindex>)
-                 -> tensor<6xf64>, (tensor<4xindex>, tensor<6x2xindex>), index, (i32, i64)
+                 -> tensor<6xf64>, (tensor<4xindex>, tensor<6x2xindex>), index, (i32, tensor<i64>)
 
     // CHECK-NEXT: ( 1, 2, 3, 4, 5, {{.*}} )
     %vbd = vector.transfer_read %bd[%c0], %f0 : tensor<6xf64>, vector<6xf64>
@@ -231,7 +231,8 @@ module {
     %vbi = vector.transfer_read %bi[%c0, %c0], %c0 : tensor<6x2xindex>, vector<6x2xindex>
    vector.print %vbi : vector<6x2xindex>
     // CHECK-NEXT: 10
-    vector.print %li : i64
+    %si = tensor.extract %li[] : tensor<i64>
+    vector.print %si : i64
     return
   }
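For readers less familiar with 0-ranked tensors, here is a minimal sketch, not taken from the patch itself, of the round trip that genScalarToTensor and the updated integration test perform. It assumes that genCast lowers an index-to-i64 conversion to arith.index_cast; the function name and the %sz argument are hypothetical.

// Minimal sketch (assumed lowering, not the patch's exact output): a length
// computed as an index is wrapped into a 0-ranked tensor, and a consumer
// reads the scalar back out with an empty index list.
func.func @scalar_to_tensor_roundtrip(%sz : index) -> i64 {
  // genScalarToTensor first casts the length to the requested element type ...
  %len = arith.index_cast %sz : index to i64
  // ... then materializes it as a 0-ranked tensor via tensor.from_elements,
  // which is why the codegen pass now marks tensor::FromElementsOp legal.
  %t = tensor.from_elements %len : tensor<i64>
  // Consumers extract the scalar with an empty index list, exactly as the
  // updated sparse_pack.mlir test does before vector.print.
  %s = tensor.extract %t[] : tensor<i64>
  return %s : i64
}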
@@ -438,6 +438,10 @@ class RankedSparseTensorOf<list<Type> allowedTypes>

 def AnyRankedSparseTensor : RankedSparseTensorOf<[AnyType]>;

+class ScalarLikeOf<list<Type> allowedTypes>
This is technically no longer a "sparse tensor trait" as defined by the header of this section (so in the long run we may want to promote this to a more general place). But OK for now.
Yeah, agree.