Skip to content

[sparse] allow unpack op to return 0-ranked tensor type. #66269

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 13, 2023

Conversation

PeimingLiu
Copy link
Member

Many frontends canonicalize scalar into 0-ranked tensor, it change will hopefully make the operation easier to use for those cases.

@PeimingLiu PeimingLiu requested a review from a team as a code owner September 13, 2023 18:17
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:sparse Sparse compiler in MLIR mlir labels Sep 13, 2023
@llvmbot
Copy link
Member

llvmbot commented Sep 13, 2023

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir-sparse

Changes Many frontends canonicalize scalar into 0-ranked tensor, it change will hopefully make the operation easier to use for those cases. -- Full diff: https://github.com//pull/66269.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td (+4)
  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td (+2-2)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp (+15-2)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp (+2)
  • (modified) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir (+3-2)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index e2f3df005b70d69..bf077db43ec10e9 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -438,6 +438,10 @@ class RankedSparseTensorOf<list<Type> allowedTypes>
 
 def AnyRankedSparseTensor : RankedSparseTensorOf<[AnyType]>;
 
+class ScalarLikeOf<list<Type> allowedTypes>
+  : AnyTypeOf<[0DTensorOf<allowedTypes>, AnyTypeOf<allowedTypes>]>;
+
+
 //===----------------------------------------------------------------------===//
 // Sparse Tensor Sorting Algorithm Attribute.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 7d9f1d3b26c0678..7430a3c6118cef4 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -108,8 +108,8 @@ def SparseTensor_UnpackOp : SparseTensor_Op<"unpack", [Pure, SameVariadicResultS
                    Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels)>,
     Results<(outs TensorOf<[AnyType]>:$ret_values,
                   Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
-                  AnySignlessIntegerOrIndex:$val_len,
-                  Variadic<AnySignlessIntegerOrIndex>:$lvl_lens)> {
+                  ScalarLikeOf<[AnySignlessIntegerOrIndex]>:$val_len,
+                  Variadic<ScalarLikeOf<[AnySignlessIntegerOrIndex]>>:$lvl_lens)> {
   let summary = "Returns the (values, coordinates) pair unpacked from the input tensor";
 
   let description = [{
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 0c8a304841c10d5..557c5c471c4a77c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -559,6 +559,18 @@ static ReassociationIndices getReassociationForFlattening(ShapedType srcTp) {
   return reassociation;
 }
 
+static Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem,
+                               Type dstTp) {
+  if (auto rtp = dstTp.dyn_cast<RankedTensorType>()) {
+    // Scalars can only be converted to 0-ranked tensors.
+    if (rtp.getRank() != 0)
+      return nullptr;
+    elem = genCast(builder, loc, elem, rtp.getElementType());
+    return builder.create<tensor::FromElementsOp>(loc, rtp, elem);
+  }
+  return genCast(builder, loc, elem, dstTp);
+}
+
 //===----------------------------------------------------------------------===//
 // Codegen rules.
 //===----------------------------------------------------------------------===//
@@ -1324,7 +1336,8 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
         // consistent.
         retMem.insert(retMem.begin(), dst);
         Type valLenTp = op.getValLen().getType();
-        retLen.insert(retLen.begin(), genCast(rewriter, loc, sz, valLenTp));
+        retLen.insert(retLen.begin(),
+                      genScalarToTensor(rewriter, loc, sz, valLenTp));
       } else {
         assert(fKind == SparseTensorFieldKind::PosMemRef ||
                fKind == SparseTensorFieldKind::CrdMemRef);
@@ -1337,7 +1350,7 @@ struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
         retMem.push_back(dst);
         // Retrieves the corresponding level length type.
         Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
-        retLen.push_back(genCast(rewriter, loc, sz, lvlLenTp));
+        retLen.push_back(genScalarToTensor(rewriter, loc, sz, lvlLenTp));
       }
       Value flatOut = dst;
       if (dst.getType().getRank() != 1) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index cce26bc603eeb3c..2956cf57ade0290 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -214,6 +214,8 @@ struct SparseTensorCodegenPass
     target.addLegalOp<GetStorageSpecifierOp>();
     target.addLegalOp<SetStorageSpecifierOp>();
     target.addLegalOp<StorageSpecifierInitOp>();
+    // tensor::FromElementsOp might be yield after lowering unpack.
+    target.addLegalOp<tensor::FromElementsOp>();
     // All dynamic rules below accept new function, call, return, and
     // various tensor and bufferization operations as legal output of the
     // rewriting provided that all sparse tensor types have been fully
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
index cc8d538e6adfb83..d95efb507765403 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
@@ -219,7 +219,7 @@ module {
     %boi = tensor.empty() : tensor<6x2xindex>
     %bd, %bp, %bi, %ld, %lp, %li = sparse_tensor.unpack %bs : tensor<2x10x10xf64, #BCOO>
                     outs(%bod, %bop, %boi : tensor<6xf64>, tensor<4xindex>, tensor<6x2xindex>)
-                    -> tensor<6xf64>, (tensor<4xindex>, tensor<6x2xindex>), index, (i32, i64)
+                    -> tensor<6xf64>, (tensor<4xindex>, tensor<6x2xindex>), index, (i32, tensor<i64>)
 
     // CHECK-NEXT: ( 1, 2, 3, 4, 5, {{.*}} )
     %vbd = vector.transfer_read %bd[%c0], %f0 : tensor<6xf64>, vector<6xf64>
@@ -231,7 +231,8 @@ module {
     %vbi = vector.transfer_read %bi[%c0, %c0], %c0 : tensor<6x2xindex>, vector<6x2xindex>
     vector.print %vbi : vector<6x2xindex>
     // CHECK-NEXT: 10
-    vector.print %li : i64
+    %si = tensor.extract %li[] : tensor<i64>
+    vector.print %si : i64
 
     return
   }

@@ -438,6 +438,10 @@ class RankedSparseTensorOf<list<Type> allowedTypes>

def AnyRankedSparseTensor : RankedSparseTensorOf<[AnyType]>;

class ScalarLikeOf<list<Type> allowedTypes>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is technically no longer a "sparse tensor trait" as defined by the header of this section (so in the long run we may want to promote this to a more general place). But OK for now.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, agree.

@PeimingLiu PeimingLiu merged commit 098f46d into llvm:main Sep 13, 2023
@PeimingLiu PeimingLiu deleted the pl-workspace branch September 13, 2023 18:33
ZijunZhaoCCK pushed a commit to ZijunZhaoCCK/llvm-project that referenced this pull request Sep 19, 2023
Many frontends canonicalize scalar into 0-ranked tensor, it change will
hopefully make the operation easier to use for those cases.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants