[mlir][sparse] add verification of absent value in sparse_tensor.unary #70248


Merged 1 commit into llvm:main on Oct 25, 2023

Conversation

@aartbik (Contributor) commented Oct 25, 2023

This value should always be a plain constant or something invariant computed outside the surrounding linalg operation, since there is no co-iteration defined on anything done in this branch.

Fixes:
#69395
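The rule this PR enforces reduces to a def-site check on the value yielded by the absent region: a constant is always fine, but anything defined inside the absent block or inside the surrounding linalg block is rejected. The following is a minimal Python sketch of that decision logic only; the class names are hypothetical stand-ins for MLIR's `Value`/`BlockArgument`, and the real implementation is `UnaryOp::verify()` in SparseTensorDialect.cpp.

```python
# Toy model of the invariance check added to UnaryOp::verify().
# The Value class below is an illustrative stand-in, not an MLIR API.
from dataclasses import dataclass
from typing import Optional

@dataclass
class Value:
    defining_block: str          # block in which the value is defined
    is_constant: bool = False    # arith.constant-like operation result
    is_block_arg: bool = False   # block argument (e.g. a linalg region arg)

def verify_absent_yield(yielded: Value, absent_block: str,
                        linalg_block: str) -> Optional[str]:
    """Return an error message, or None if the yielded value is allowed."""
    if yielded.is_block_arg:
        # Yielding a linalg region argument (like %a) is invalid: there is
        # no co-iterated input value on the absent branch.
        if yielded.defining_block == linalg_block:
            return "absent region cannot yield linalg argument"
        return None
    # Constants are always allowed; any other value computed inside the
    # absent region or the surrounding linalg block is rejected.
    if (not yielded.is_constant and
            yielded.defining_block in (absent_block, linalg_block)):
        return "absent region cannot yield locally computed value"
    return None
```

A value defined entirely outside the `linalg.generic` body (for example, a function-level constant such as `%m1` in the updated op documentation) passes the check, matching the two new `invalid.mlir` tests below.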

@llvmbot added the `mlir:sparse` (Sparse compiler in MLIR) and `mlir` labels on Oct 25, 2023
@llvmbot (Member) commented Oct 25, 2023

@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Aart Bik (aartbik)

Full diff: https://github.com/llvm/llvm-project/pull/70248.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td (+43-40)
  • (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+21-17)
  • (modified) mlir/test/Dialect/SparseTensor/invalid.mlir (+51)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 8c33e8651b1694e..50f5e7335dc923b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -624,11 +624,11 @@ def SparseTensor_InsertOp : SparseTensor_Op<"insert",
   string summary = "Inserts a value into the sparse tensor";
   string description = [{
     Inserts the value into the underlying storage of the tensor at the
-    given level-coordinates.  The arity of `lvlCoords` must match the
-    level-rank of the tensor.  This operation can only be applied when
-    the tensor materializes unintialized from a `bufferization.alloc_tensor`
-    operation and the final tensor is constructed with a `load` operation
-    which has the `hasInserts` attribute set.
+    given level-coordinates. The arity of `lvlCoords` must match the
+    level-rank of the tensor. This operation can only be applied when
+    the tensor materializes unintialized from a `tensor.empty` operation
+    and the final tensor is constructed with a `load` operation which
+    has the `hasInserts` attribute set.
 
     The level-properties of the sparse tensor type fully describe what
     kind of insertion order is allowed.  When all levels have "unique"
@@ -974,7 +974,7 @@ def SparseTensor_BinaryOp : SparseTensor_Op<"binary", [Pure]>,
       Example of isEqual applied to intersecting elements only:
 
       ```mlir
-      %C = bufferization.alloc_tensor...
+      %C = tensor.empty(...)
       %0 = linalg.generic #trait
         ins(%A: tensor<?xf64, #SparseVector>,
             %B: tensor<?xf64, #SparseVector>)
@@ -996,7 +996,7 @@ def SparseTensor_BinaryOp : SparseTensor_Op<"binary", [Pure]>,
       Example of A+B in upper triangle, A-B in lower triangle:
 
       ```mlir
-      %C = bufferization.alloc_tensor...
+      %C = tensor.empty(...)
       %1 = linalg.generic #trait
         ins(%A: tensor<?x?xf64, #CSR>, %B: tensor<?x?xf64, #CSR>
         outs(%C: tensor<?x?xf64, #CSR> {
@@ -1029,7 +1029,7 @@ def SparseTensor_BinaryOp : SparseTensor_Op<"binary", [Pure]>,
       because we never use its values, only its sparse structure:
 
       ```mlir
-      %C = bufferization.alloc_tensor...
+      %C = tensor.empty(...)
       %2 = linalg.generic #trait
         ins(%A: tensor<?x?xf64, #CSR>, %B: tensor<?x?xi32, #CSR>
         outs(%C: tensor<?x?xf64, #CSR> {
@@ -1069,7 +1069,9 @@ def SparseTensor_UnaryOp : SparseTensor_Op<"unary", [Pure]>,
       Each region contains a single block describing the computation and result.
       A non-empty block must end with a sparse_tensor.yield and the return type
       must match the type of `output`. The primary region's block has one
-      argument, while the missing region's block has zero arguments.
+      argument, while the missing region's block has zero arguments. The
+      absent region may only generate constants or values already computed
+      on entry of the `linalg.generic` operation.
 
       A region may also be declared empty (i.e. `absent={}`), indicating that the
       region does not contribute to the output.
@@ -1082,17 +1084,17 @@ def SparseTensor_UnaryOp : SparseTensor_Op<"unary", [Pure]>,
       Example of A+1, restricted to existing elements:
 
       ```mlir
-      %C = bufferization.alloc_tensor...
+      %C = tensor.empty(...) : tensor<?xf64, #SparseVector>
       %0 = linalg.generic #trait
          ins(%A: tensor<?xf64, #SparseVector>)
         outs(%C: tensor<?xf64, #SparseVector>) {
         ^bb0(%a: f64, %c: f64) :
           %result = sparse_tensor.unary %a : f64 to f64
             present={
-              ^bb0(%arg0: f64):
-                %cf1 = arith.constant 1.0 : f64
-                %ret = arith.addf %arg0, %cf1 : f64
-                sparse_tensor.yield %ret : f64
+            ^bb0(%arg0: f64):
+              %cf1 = arith.constant 1.0 : f64
+              %ret = arith.addf %arg0, %cf1 : f64
+              sparse_tensor.yield %ret : f64
             }
             absent={}
           linalg.yield %result : f64
@@ -1102,41 +1104,42 @@ def SparseTensor_UnaryOp : SparseTensor_Op<"unary", [Pure]>,
       Example returning +1 for existing values and -1 for missing values:
 
       ```mlir
-      %C = bufferization.alloc_tensor...
+      %p1 = arith.constant  1 : i32
+      %m1 = arith.constant -1 : i32
+      %C = tensor.empty(...) : tensor<?xi32, #SparseVector>
       %1 = linalg.generic #trait
          ins(%A: tensor<?xf64, #SparseVector>)
-        outs(%C: tensor<?xf64, #SparseVector>) {
-        ^bb0(%a: f64, %c: f64) :
+        outs(%C: tensor<?xi32, #SparseVector>) {
+        ^bb0(%a: f64, %c: i32) :
           %result = sparse_tensor.unary %a : f64 to i32
             present={
             ^bb0(%x: f64):
-              %ret = arith.constant 1 : i32
-              sparse_tensor.yield %ret : i32
-          }
-          absent={
-            %ret = arith.constant -1 : i32
-            sparse_tensor.yield %ret : i32
-          }
-          linalg.yield %result : f64
-      } -> tensor<?xf64, #SparseVector>
+              sparse_tensor.yield %p1 : i32
+            }
+            absent={
+              sparse_tensor.yield %m1 : i32
+            }
+          linalg.yield %result : i32
+      } -> tensor<?xi32, #SparseVector>
       ```
 
       Example showing a structural inversion (existing values become missing in
       the output, while missing values are filled with 1):
 
       ```mlir
-      %C = bufferization.alloc_tensor...
+      %c1 = arith.constant 1 : i64
+      %C = tensor.empty(...) : tensor<?xi64, #SparseVector>
       %2 = linalg.generic #trait
-          ins(%A: tensor<?xf64, #SparseVector>)
-          outs(%C: tensor<?xf64, #SparseVector>) {
-            %result = sparse_tensor.unary %a : f64 to i64
-              present={}
-              absent={
-                %ret = arith.constant 1 : i64
-                sparse_tensor.yield %ret : i64
-              }
-          linalg.yield %result : f64
-      } -> tensor<?xf64, #SparseVector>
+         ins(%A: tensor<?xf64, #SparseVector>)
+        outs(%C: tensor<?xi64, #SparseVector>) {
+        ^bb0(%a: f64, %c: i64) :
+          %result = sparse_tensor.unary %a : f64 to i64
+            present={}
+            absent={
+              sparse_tensor.yield %c1 : i64
+            }
+          linalg.yield %result : i64
+      } -> tensor<?xi64, #SparseVector>
       ```
   }];
 
@@ -1177,7 +1180,7 @@ def SparseTensor_ReduceOp : SparseTensor_Op<"reduce", [Pure, SameOperandsAndResu
       ```mlir
       %cf1 = arith.constant 1.0 : f64
       %cf100 = arith.constant 100.0 : f64
-      %C = bufferization.alloc_tensor...
+      %C = tensor.empty(...)
       %0 = linalg.generic #trait
          ins(%A: tensor<?x?xf64, #SparseMatrix>)
         outs(%C: tensor<?xf64, #SparseVector>) {
@@ -1220,7 +1223,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
       Example of selecting A >= 4.0:
 
       ```mlir
-      %C = bufferization.alloc_tensor...
+      %C = tensor.empty(...)
       %0 = linalg.generic #trait
          ins(%A: tensor<?xf64, #SparseVector>)
         outs(%C: tensor<?xf64, #SparseVector>) {
@@ -1238,7 +1241,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
       Example of selecting lower triangle of a matrix:
 
       ```mlir
-      %C = bufferization.alloc_tensor...
+      %C = tensor.empty(...)
       %1 = linalg.generic #trait
          ins(%A: tensor<?x?xf64, #CSR>)
         outs(%C: tensor<?x?xf64, #CSR>) {
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 17e6ef53fe596e0..f05cbd8d16d9a76 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -34,8 +34,13 @@
 using namespace mlir;
 using namespace mlir::sparse_tensor;
 
+#define RETURN_FAILURE_IF_FAILED(X)                                            \
+  if (failed(X)) {                                                             \
+    return failure();                                                          \
+  }
+
 //===----------------------------------------------------------------------===//
-// Additional convenience methods.
+// Local convenience methods.
 //===----------------------------------------------------------------------===//
 
 static constexpr bool acceptBitWidth(unsigned bitWidth) {
@@ -52,7 +57,7 @@ static constexpr bool acceptBitWidth(unsigned bitWidth) {
 }
 
 //===----------------------------------------------------------------------===//
-// StorageLayout
+// SparseTensorDialect StorageLayout.
 //===----------------------------------------------------------------------===//
 
 static constexpr Level kInvalidLevel = -1u;
@@ -183,7 +188,7 @@ StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
 }
 
 //===----------------------------------------------------------------------===//
-// TensorDialect Attribute Methods.
+// SparseTensorDialect Attribute Methods.
 //===----------------------------------------------------------------------===//
 
 std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
@@ -658,11 +663,6 @@ SparseTensorEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
-#define RETURN_FAILURE_IF_FAILED(X)                                            \
-  if (failed(X)) {                                                             \
-    return failure();                                                          \
-  }
-
 LogicalResult SparseTensorEncodingAttr::verifyEncoding(
     ArrayRef<DynSize> dimShape, Type elementType,
     function_ref<InFlightDiagnostic()> emitError) const {
@@ -685,7 +685,7 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
 }
 
 //===----------------------------------------------------------------------===//
-// Convenience Methods.
+// Convenience methods.
 //===----------------------------------------------------------------------===//
 
 SparseTensorEncodingAttr
@@ -1365,10 +1365,6 @@ LogicalResult SetStorageSpecifierOp::verify() {
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// TensorDialect Linalg.Generic Operations.
-//===----------------------------------------------------------------------===//
-
 template <class T>
 static LogicalResult verifyNumBlockArgs(T *op, Region &region,
                                         const char *regionName,
@@ -1445,6 +1441,18 @@ LogicalResult UnaryOp::verify() {
   if (!absent.empty()) {
     RETURN_FAILURE_IF_FAILED(
         verifyNumBlockArgs(this, absent, "absent", TypeRange{}, outputType))
+    // Absent branch can only yield invariant values.
+    Block *absentBlock = &absent.front();
+    Block *parent = getOperation()->getBlock();
+    Value absentVal = cast<YieldOp>(absentBlock->getTerminator()).getResult();
+    if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
+      if (arg.getOwner() == parent)
+        return emitError("absent region cannot yield linalg argument");
+    } else if (Operation *def = absentVal.getDefiningOp()) {
+      if (!isa<arith::ConstantOp>(def) &&
+          (def->getBlock() == absentBlock || def->getBlock() == parent))
+        return emitError("absent region cannot yield locally computed value");
+    }
   }
   return success();
 }
@@ -1719,10 +1727,6 @@ LogicalResult YieldOp::verify() {
 
 #undef RETURN_FAILURE_IF_FAILED
 
-//===----------------------------------------------------------------------===//
-// TensorDialect Methods.
-//===----------------------------------------------------------------------===//
-
 /// Materialize a single constant operation from a given attribute value with
 /// the desired resultant type.
 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 33aa81c5a747d9b..0217ef152be6a0d 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -544,6 +544,57 @@ func.func @invalid_unary_wrong_yield(%arg0: f64) -> f64 {
 
 // -----
 
+
+#SparseVector = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
+
+#trait = {
+  indexing_maps = [ affine_map<(i) -> (i)>, affine_map<(i) -> (i)> ],
+  iterator_types = ["parallel"]
+}
+
+func.func @invalid_absent_value(%arg0 : tensor<100xf64, #SparseVector>) -> tensor<100xf64, #SparseVector> {
+  %C = tensor.empty() : tensor<100xf64, #SparseVector>
+  %0 = linalg.generic #trait
+    ins(%arg0: tensor<100xf64, #SparseVector>)
+    outs(%C: tensor<100xf64, #SparseVector>) {
+     ^bb0(%a: f64, %c: f64) :
+        // expected-error@+1 {{absent region cannot yield linalg argument}}
+        %result = sparse_tensor.unary %a : f64 to f64
+           present={}
+           absent={ sparse_tensor.yield %a : f64 }
+        linalg.yield %result : f64
+    } -> tensor<100xf64, #SparseVector>
+  return %0 : tensor<100xf64, #SparseVector>
+}
+
+// -----
+
+#SparseVector = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
+
+#trait = {
+  indexing_maps = [ affine_map<(i) -> (i)>, affine_map<(i) -> (i)> ],
+  iterator_types = ["parallel"]
+}
+
+func.func @invalid_absent_computation(%arg0 : tensor<100xf64, #SparseVector>) -> tensor<100xf64, #SparseVector> {
+  %f0 = arith.constant 0.0 : f64
+  %C = tensor.empty() : tensor<100xf64, #SparseVector>
+  %0 = linalg.generic #trait
+    ins(%arg0: tensor<100xf64, #SparseVector>)
+    outs(%C: tensor<100xf64, #SparseVector>) {
+     ^bb0(%a: f64, %c: f64) :
+        %v = arith.addf %a, %f0 : f64
+        // expected-error@+1 {{absent region cannot yield locally computed value}}
+        %result = sparse_tensor.unary %a : f64 to f64
+           present={}
+           absent={ sparse_tensor.yield %v : f64 }
+        linalg.yield %result : f64
+    } -> tensor<100xf64, #SparseVector>
+  return %0 : tensor<100xf64, #SparseVector>
+}
+
+// -----
+
 func.func @invalid_reduce_num_args_mismatch(%arg0: f64, %arg1: f64) -> f64 {
   %cf1 = arith.constant 1.0 : f64
   // expected-error@+1 {{reduce region must have exactly 2 arguments}}

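For intuition on why the absent branch must be invariant: during sparse code generation, only stored entries are co-iterated, so the absent branch has no per-element input to consume and its single value is reused for every missing position. A minimal dense-reference Python sketch of `sparse_tensor.unary` semantics (the `{index: value}` dict encoding is an illustrative assumption, not MLIR's actual sparse storage):

```python
# Dense-reference model of sparse_tensor.unary over a sparse vector,
# represented as {index: value} plus a logical length (illustrative only).

def sparse_unary(stored, length, present, absent_value):
    """Apply `present` to each stored entry; fill every missing position
    with the single, loop-invariant `absent_value`. Because the absent
    branch is never co-iterated, it cannot depend on a per-element
    argument such as %a -- hence the verifier's restriction."""
    out = {}
    for i in range(length):
        if i in stored:
            r = present(stored[i])
            if r is not None:            # present={} contributes nothing
                out[i] = r
        elif absent_value is not None:   # absent={} leaves entry missing
            out[i] = absent_value
    return out

# Mirrors the "+1 for existing, -1 for missing" example in the op docs:
result = sparse_unary({0: 3.0, 2: 5.0}, 4, lambda a: 1, -1)  # {0: 1, 1: -1, 2: 1, 3: -1}
```

The structural-inversion example from the op documentation corresponds to `present` returning nothing and `absent_value` being a constant, e.g. `sparse_unary({0: 3.0}, 2, lambda a: None, 1)`.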
@aartbik aartbik merged commit 7e83a1a into llvm:main Oct 25, 2023
@aartbik aartbik deleted the bik branch October 25, 2023 20:56
zahiraam pushed a commit to zahiraam/llvm-project referencing this pull request on Oct 26, 2023 (llvm#70248)