[mlir][sparse] avoid incompatible linalg fuse-into-consumer #86752

Merged: 1 commit into llvm:main on Mar 27, 2024

Conversation

aartbik (Contributor) commented Mar 27, 2024

This fixes an "infinite" loop bug, where the incoming IR was repeatedly rewritten while adding identical cast operations. The test for compatible types should include the notion of an encoding. If it differs, then a naive fusion into the consumer is invalid.

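For illustration, here is a minimal sketch of the kind of cast that triggered the loop, reduced from the test added in this PR (the encoding, shapes, and names below are illustrative, not the committed test): the source and target types have the same rank and static shape and differ only in their encoding, so the pre-fix compatibility test considered the cast foldable into its consumer.

```mlir
// Illustrative sketch: a cast that changes only the encoding. Same rank,
// same static sizes, different encoding. The old preservesStaticInformation()
// reported "compatible", so the rewrite kept re-inserting an identical cast;
// the encoding check added below rejects the fold.
#sparse = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>

func.func @encoding_only_cast(%arg0: tensor<8xf64>) -> tensor<8xf64, #sparse> {
  %0 = tensor.cast %arg0 : tensor<8xf64> to tensor<8xf64, #sparse>
  return %0 : tensor<8xf64, #sparse>
}
```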
llvmbot (Member) commented Mar 27, 2024

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir-sparse

Author: Aart Bik (aartbik)

Changes

This fixes an "infinite" loop bug, where the incoming IR was repeatedly rewritten while adding identical cast operations. The test for compatible types should include the notion of an encoding. If it differs, then a naive fusion into the consumer is invalid.


Full diff: https://github.com/llvm/llvm-project/pull/86752.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+4)
  • (added) mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir (+47)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index dc8843aa4e1e13..38a9ad60bb7948 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -276,6 +276,10 @@ bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
   if (sourceType.getRank() != targetType.getRank())
     return false;
 
+  // Requires same encoding.
+  if (sourceType.getEncoding() != targetType.getEncoding())
+    return false;
+
   // If cast is towards more static sizes along any dimension, don't fold.
   for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
     if (!ShapedType::isDynamic(std::get<0>(t)) &&
diff --git a/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir b/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir
new file mode 100644
index 00000000000000..bbc7f397e793fe
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/no_fold_into_consumer.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt %s --canonicalize --pre-sparsification-rewrite | FileCheck %s
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+#sparse = #sparse_tensor.encoding<{
+  map = (d0, d1, d2) ->
+          (d0 : compressed(nonunique),
+	   d1 : singleton(nonunique, soa),
+	   d2 : singleton(soa)),
+  posWidth = 64,
+  crdWidth = 64
+}>
+
+
+module {
+  //
+  // This IR should not end up in an infinite loop trying to fold
+  // the linalg producer into the tensor cast consumer (even though
+  // static sizes can fold, the different encodings cannot). The
+  // cast was sloppy to begin with (but it has been observed by
+  // external sources) and can be easily repaired by the sparsifier.
+  //
+  // CHECK-LABEL: func @avoid_fold
+  // CHECK:       arith.constant
+  // CHECK:       tensor.empty()
+  // CHECK:       linalg.generic
+  // CHECK:       sparse_tensor.convert
+  // CHECK:       return
+  //
+  func.func @avoid_fold(%0: tensor<10x20x30xf64, #sparse>) -> tensor<10x20x30xf64, #sparse> {
+    %1 = tensor.empty() : tensor<10x20x30xf64>
+    %2 = linalg.generic { indexing_maps = [#map, #map],
+                          iterator_types = ["parallel", "parallel", "parallel"]
+                        }
+    ins (%0 : tensor<10x20x30xf64, #sparse>)
+    outs(%1 : tensor<10x20x30xf64>) {
+        ^bb0(%in: f64, %out: f64):
+          %cst = arith.constant 0.000000e+00 : f64
+          %4 = arith.cmpf ugt, %in, %cst : f64
+          %5 = arith.select %4, %in, %cst : f64
+          linalg.yield %5 : f64
+    } -> tensor<10x20x30xf64>
+    %cast = tensor.cast %2 : tensor<10x20x30xf64> to tensor<10x20x30xf64, #sparse>
+    return %cast : tensor<10x20x30xf64, #sparse>
+  }
+}
+
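For reference, a sketch of the IR one would expect from this test after `--canonicalize --pre-sparsification-rewrite` with the fix applied (abbreviated and assumed; SSA names are illustrative, and only the ops matched by the CHECK lines are shown). The cast now survives canonicalization and is repaired into an explicit conversion:

```mlir
// Assumed output shape (abbreviated): instead of folding the cast into the
// linalg.generic consumer, the pipeline materializes the dense result and
// converts it to the sparse encoding.
%empty  = tensor.empty() : tensor<10x20x30xf64>
%dense  = linalg.generic { ... }
            ins(%arg0 : tensor<10x20x30xf64, #sparse>)
            outs(%empty : tensor<10x20x30xf64>) { ... } -> tensor<10x20x30xf64>
%sparse = sparse_tensor.convert %dense
            : tensor<10x20x30xf64> to tensor<10x20x30xf64, #sparse>
return %sparse : tensor<10x20x30xf64, #sparse>
```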

@aartbik aartbik merged commit 3324f4d into llvm:main Mar 27, 2024
@aartbik aartbik deleted the bik branch March 27, 2024 17:25
Labels: mlir, mlir:sparse (Sparse compiler in MLIR), mlir:tensor