[mlir][sparse] Improve sparse tensor type constraints #112133


Merged: 1 commit into main from users/matthias-springer/sparse_types on Oct 13, 2024

Conversation

matthias-springer (Member) commented on Oct 13, 2024

Sparse tensors are always ranked tensors. Encodings cannot be attached to unranked tensors. Change the type constraint to `RankedTensorOf`, so that we generate `TypedValue<RankedTensorType>` instead of `TypedValue<TensorType>`. This removes the need for type casting in some cases.
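
As a quick illustration (a sketch, not code from this PR; the free functions below are hypothetical), this is the difference the constraint change makes at C++ call sites, since ODS generates operand accessors typed by the constraint:

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

using namespace mlir;

// Old `TensorOf` constraint: the generated accessor returns
// TypedValue<TensorType>, so rank queries need an explicit cast first.
static int64_t rankViaCast(TypedValue<TensorType> tensor) {
  return cast<RankedTensorType>(tensor.getType()).getRank();
}

// New `RankedTensorOf` constraint: the accessor returns
// TypedValue<RankedTensorType>, and the cast disappears.
static int64_t rankDirect(TypedValue<RankedTensorType> tensor) {
  return tensor.getType().getRank();
}
```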

Also improve the verifiers (add missing `return` statements) and switch a few other `AnyTensor` constraints to `AnyRankedTensor`.
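
For context, a minimal sketch of the verifier bug class being fixed here (`BuggyOp`, `FixedOp`, and `violatesInvariant()` are placeholder names, not ops from this PR): `emitError` only creates a diagnostic, so a verifier that forgets to `return` it falls through and reports success for invalid IR.

```cpp
// Buggy pattern: the diagnostic is emitted, but control falls through
// and the verifier still returns success(), so invalid IR verifies.
LogicalResult BuggyOp::verify() {
  if (violatesInvariant())
    emitError("invariant violated"); // missing `return`!
  return success();
}

// Fixed pattern: returning the InFlightDiagnostic implicitly converts
// it to failure(), aborting verification with the error attached.
LogicalResult FixedOp::verify() {
  if (violatesInvariant())
    return emitError("invariant violated");
  return success();
}
```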

This commit is in preparation for a dialect conversion commit that requires fixes in the sparse dialect.

llvmbot (Member) commented on Oct 13, 2024

@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Full diff: https://github.com/llvm/llvm-project/pull/112133.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td (+2-2)
  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td (+11-11)
  • (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+48-43)
  • (modified) mlir/test/Dialect/SparseTensor/invalid.mlir (+3-3)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index cb6c1b63e4e4b0..adcf6fac752fe6 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -586,10 +586,10 @@ def IsSparseTensorSlicePred
           "  ::mlir::sparse_tensor::getSparseTensorEncoding($_self).isSlice()">;
 
 class SparseTensorOf<list<Type> allowedTypes>
-  : TensorOf<allowedTypes, [IsSparseTensorPred], "sparse tensor">;
+  : RankedTensorOf<allowedTypes, [IsSparseTensorPred], "sparse tensor">;
 
 class SparseTensorSliceOf<list<Type> allowedTypes>
-  : TensorOf<allowedTypes, [IsSparseTensorSlicePred], "sparse tensor slice">;
+  : RankedTensorOf<allowedTypes, [IsSparseTensorSlicePred], "sparse tensor slice">;
 
 class ScalarLikeOf<list<Type> allowedTypes>
   : AnyTypeOf<[0DTensorOf<allowedTypes>, AnyTypeOf<allowedTypes>], "scalar like">;
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 96a61419a541f7..2c281c9f6aa85d 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -92,8 +92,8 @@ def SparseTensor_AssembleOp : SparseTensor_Op<"assemble", [Pure]> {
     ```
   }];
 
-  let arguments = (ins Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$levels,
-                   TensorOf<[AnyType]>:$values);
+  let arguments = (ins Variadic<RankedTensorOf<[AnySignlessIntegerOrIndex]>>:$levels,
+                       RankedTensorOf<[AnyType]>:$values);
   let results = (outs AnySparseTensor: $result);
   let assemblyFormat =
     "` ` `(` $levels       `)` `,` $values attr-dict `:`"
@@ -138,12 +138,12 @@ def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVaria
   }];
 
   let arguments = (ins AnySparseTensor:$tensor,
-                   Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels,
-                   TensorOf<[AnyType]>:$out_values);
-  let results = (outs Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
-                  TensorOf<[AnyType]>:$ret_values,
-                  Variadic<AnyIndexingScalarLike>:$lvl_lens,
-                  AnyIndexingScalarLike:$val_len);
+                       Variadic<RankedTensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels,
+                       RankedTensorOf<[AnyType]>:$out_values);
+  let results = (outs Variadic<RankedTensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
+                      RankedTensorOf<[AnyType]>:$ret_values,
+                      Variadic<AnyIndexingScalarLike>:$lvl_lens,
+                      AnyIndexingScalarLike:$val_len);
   let assemblyFormat =
     "$tensor attr-dict `:` type($tensor)"
     "`out_lvls` `(` $out_levels `:` type($out_levels) `)` "
@@ -196,8 +196,8 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
 
   }];
 
-  let arguments = (ins AnyTensor:$source);
-  let results = (outs AnyTensor:$dest);
+  let arguments = (ins AnyRankedTensor:$source);
+  let results = (outs AnyRankedTensor:$dest);
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
 
   let extraClassDeclaration = [{
@@ -1447,7 +1447,7 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
   ];
 
   let regions = (region SizedRegion<1>:$region);
-  let arguments = (ins AnyTensor:$tensor,
+  let arguments = (ins AnyRankedTensor:$tensor,
                        Variadic<AnyType>:$initArgs,
                        OptionalAttr<AffineMapAttr>:$order);
   let results = (outs Variadic<AnyType>:$results);
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index b21bc1a93036c4..7b1b1f383e6343 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1310,7 +1310,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
     // The coordinates should be in shape of <? x rank>
     unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
     if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
-      op->emitError("input/output trailing COO level-ranks don't match");
+      return op->emitError("input/output trailing COO level-ranks don't match");
     }
   }
 
@@ -1350,7 +1350,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
 }
 
 LogicalResult AssembleOp::verify() {
-  const auto valuesTp = getRankedTensorType(getValues());
+  RankedTensorType valuesTp = getValues().getType();
   const auto lvlsTp = getLevels().getTypes();
   const auto resTp = getSparseTensorType(getResult());
   return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
@@ -1364,34 +1364,31 @@ LogicalResult DisassembleOp::verify() {
     if (ot.getType() != rt.getType())
       return emitError("output levels and return levels type mismatch");
 
-  const auto valuesTp = getRankedTensorType(getRetValues());
+  RankedTensorType valuesTp = getRetValues().getType();
   const auto lvlsTp = getRetLevels().getTypes();
   const auto srcTp = getSparseTensorType(getTensor());
   return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
 }
 
 LogicalResult ConvertOp::verify() {
-  if (auto tp1 = llvm::dyn_cast<RankedTensorType>(getSource().getType())) {
-    if (auto tp2 = llvm::dyn_cast<RankedTensorType>(getDest().getType())) {
-      if (tp1.getRank() != tp2.getRank())
-        return emitError("unexpected conversion mismatch in rank");
-      auto dstEnc =
-          llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
-      if (dstEnc && dstEnc.isSlice())
-        return emitError("cannot convert to a sparse tensor slice");
-
-      auto shape1 = tp1.getShape();
-      auto shape2 = tp2.getShape();
-      // Accept size matches between the source and the destination type
-      // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
-      // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
-      for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
-        if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
-          return emitError("unexpected conversion mismatch in dimension ") << d;
-      return success();
-    }
-  }
-  return emitError("unexpected type in convert");
+  RankedTensorType tp1 = getSource().getType();
+  RankedTensorType tp2 = getDest().getType();
+  if (tp1.getRank() != tp2.getRank())
+    return emitError("unexpected conversion mismatch in rank");
+  auto dstEnc =
+      llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
+  if (dstEnc && dstEnc.isSlice())
+    return emitError("cannot convert to a sparse tensor slice");
+
+  auto shape1 = tp1.getShape();
+  auto shape2 = tp2.getShape();
+  // Accept size matches between the source and the destination type
+  // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
+  // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
+  for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
+    if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
+      return emitError("unexpected conversion mismatch in dimension ") << d;
+  return success();
 }
 
 OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
@@ -1495,7 +1492,8 @@ LogicalResult LvlOp::verify() {
   if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
     auto stt = getSparseTensorType(getSource());
     if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
-      emitError("Level index exceeds the rank of the input sparse tensor");
+      return emitError(
+          "Level index exceeds the rank of the input sparse tensor");
   }
   return success();
 }
@@ -1697,14 +1695,14 @@ LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
 }
 
 LogicalResult ToSliceOffsetOp::verify() {
-  auto rank = getRankedTensorType(getSlice()).getRank();
+  auto rank = getSlice().getType().getRank();
   if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
     return emitError("requested dimension out of bound");
   return success();
 }
 
 LogicalResult ToSliceStrideOp::verify() {
-  auto rank = getRankedTensorType(getSlice()).getRank();
+  auto rank = getSlice().getType().getRank();
   if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
     return emitError("requested dimension out of bound");
   return success();
@@ -1986,15 +1984,16 @@ LogicalResult ForeachOp::verify() {
   const auto iTp = IndexType::get(getContext());
   for (Dimension d = 0; d < dimRank; d++)
     if (args[d].getType() != iTp)
-      emitError(
+      return emitError(
           llvm::formatv("Expecting Index type for argument at index {0}", d));
 
   const auto elemTp = t.getElementType();
   const auto valueTp = args[dimRank].getType();
   if (elemTp != valueTp)
-    emitError(llvm::formatv("Unmatched element type between input tensor and "
-                            "block argument, expected:{0}, got: {1}",
-                            elemTp, valueTp));
+    return emitError(
+        llvm::formatv("Unmatched element type between input tensor and "
+                      "block argument, expected:{0}, got: {1}",
+                      elemTp, valueTp));
   return success();
 }
 
@@ -2011,15 +2010,15 @@ LogicalResult ReorderCOOOp::verify() {
   SparseTensorType dstStt = getSparseTensorType(getResultCoo());
 
   if (!srcStt.isCOOType() || !dstStt.isCOOType())
-    emitError("Expected COO sparse tensors only");
+    return emitError("Expected COO sparse tensors only");
 
   if (!srcStt.hasSameDimToLvl(dstStt))
-    emitError("Unmatched dim2lvl map between input and result COO");
+    return emitError("Unmatched dim2lvl map between input and result COO");
 
   if (srcStt.getPosType() != dstStt.getPosType() ||
       srcStt.getCrdType() != dstStt.getCrdType() ||
       srcStt.getElementType() != dstStt.getElementType())
-    emitError("Unmatched storage format between input and result COO");
+    return emitError("Unmatched storage format between input and result COO");
 
   return success();
 }
@@ -2044,10 +2043,11 @@ LogicalResult SortOp::verify() {
   AffineMap xPerm = getPermMap();
   uint64_t nx = xPerm.getNumDims();
   if (nx < 1)
-    emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
+    return emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
 
   if (!xPerm.isPermutation())
-    emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm));
+    return emitError(
+        llvm::formatv("Expected a permutation map, got {0}", xPerm));
 
   // We can't check the size of the buffers when n or buffer dimensions aren't
   // compile-time constants.
@@ -2056,19 +2056,24 @@ LogicalResult SortOp::verify() {
     return success();
 
   // Verify dimensions.
-  const auto checkDim = [&](Value v, Size minSize, const char *message) {
+  const auto checkDim = [&](Value v, Size minSize,
+                            const char *message) -> LogicalResult {
     const Size sh = getMemRefType(v).getShape()[0];
     if (!ShapedType::isDynamic(sh) && sh < minSize)
-      emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
+      return emitError(
+          llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
+    return success();
   };
   uint64_t n = cn.value();
   uint64_t ny = 0;
   if (auto nyAttr = getNyAttr())
     ny = nyAttr.getInt();
-  checkDim(getXy(), n * (nx + ny),
-           "Expected dimension(xy) >= n * (rank(perm_map) + ny)");
+  if (failed(checkDim(getXy(), n * (nx + ny),
+                      "Expected dimension(xy) >= n * (rank(perm_map) + ny)")))
+    return failure();
   for (Value opnd : getYs())
-    checkDim(opnd, n, "Expected dimension(y) >= n");
+    if (failed(checkDim(opnd, n, "Expected dimension(y) >= n")))
+      return failure();
 
   return success();
 }
@@ -2101,8 +2106,8 @@ static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
   }
 
   if (lvlHi <= lvlLo)
-    parser.emitError(parser.getNameLoc(),
-                     "expect larger level upper bound than lower bound");
+    return parser.emitError(parser.getNameLoc(),
+                            "expect larger level upper bound than lower bound");
 
   return success();
 }
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 737b736ba795fe..908d2d8aa83f7c 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -105,7 +105,7 @@ func.func @invalid_positions_dense(%arg0: tensor<128xf64>) -> memref<?xindex> {
 
 func.func @invalid_positions_unranked(%arg0: tensor<*xf64>) -> memref<?xindex> {
   // expected-error@+1 {{'sparse_tensor.positions' op operand #0 must be sparse tensor of any type values, but got 'tensor<*xf64>'}}
-  %0 = sparse_tensor.positions %arg0 { level = 0 : index } : tensor<*xf64> to memref<?xindex>
+  %0 = "sparse_tensor.positions"(%arg0) { level = 0 : index } : (tensor<*xf64>) -> (memref<?xindex>)
   return %0 : memref<?xindex>
 }
 
@@ -141,7 +141,7 @@ func.func @invalid_indices_dense(%arg0: tensor<10x10xi32>) -> memref<?xindex> {
 
 func.func @invalid_indices_unranked(%arg0: tensor<*xf64>) -> memref<?xindex> {
   // expected-error@+1 {{'sparse_tensor.coordinates' op operand #0 must be sparse tensor of any type values, but got 'tensor<*xf64>'}}
-  %0 = sparse_tensor.coordinates %arg0 { level = 0 : index } : tensor<*xf64> to memref<?xindex>
+  %0 = "sparse_tensor.coordinates"(%arg0) { level = 0 : index } : (tensor<*xf64>) -> (memref<?xindex>)
   return %0 : memref<?xindex>
 }
 
@@ -347,7 +347,7 @@ func.func @sparse_wrong_arity_compression(%arg0: memref<?xf64>,
 // -----
 
 func.func @sparse_convert_unranked(%arg0: tensor<*xf32>) -> tensor<10xf32> {
-  // expected-error@+1 {{unexpected type in convert}}
+  // expected-error@+1 {{invalid kind of type specified}}
   %0 = sparse_tensor.convert %arg0 : tensor<*xf32> to tensor<10xf32>
   return %0 : tensor<10xf32>
 }

PeimingLiu (Member) commented:

Thanks! LGTM!

matthias-springer merged commit 77f8297 into main on Oct 13, 2024
11 checks passed
matthias-springer deleted the users/matthias-springer/sparse_types branch on October 13, 2024 at 19:12
DanielCChen pushed a commit to DanielCChen/llvm-project that referenced this pull request Oct 16, 2024