[mlir][linalg][bufferize] Fix element-wise access optimization for sparse tensors #87305
Conversation
`linalg.generic` ops with sparse tensors do not necessarily bufferize to element-wise access, because insertions into a sparse tensor may change the layout of (or reallocate) the underlying sparse data structures.
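The sketch below (illustrative only, not part of the patch; the encoding, shapes, and function name are made up) shows the failure mode. Every iterator is "parallel" and every indexing map is the identity, which is exactly the shape of op that the element-wise access optimization normally accepts. When the sparse compiler lowers it, however, storing into the compressed output may insert new entries and reallocate the underlying position, coordinate, and value arrays, so the op does not access its buffers one element at a time.

```mlir
#CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
#id = affine_map<(d0, d1) -> (d0, d1)>

func.func @sparse_output(%in: tensor<8x8xf32>, %out: tensor<8x8xf32, #CSR>)
    -> tensor<8x8xf32, #CSR> {
  // All-parallel iterators with identity maps: element-wise at the linalg
  // level, but writing %out may grow the compressed storage of the result.
  %0 = linalg.generic {indexing_maps = [#id, #id],
                       iterator_types = ["parallel", "parallel"]}
      ins(%in : tensor<8x8xf32>) outs(%out : tensor<8x8xf32, #CSR>) {
  ^bb0(%a: f32, %b: f32):
    linalg.yield %a : f32
  } -> tensor<8x8xf32, #CSR>
  return %0 : tensor<8x8xf32, #CSR>
}
```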
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-sparse @llvm/pr-subscribers-mlir-linalg

Author: Matthias Springer (matthias-springer)

Changes
Full diff: https://github.com/llvm/llvm-project/pull/87305.diff

2 Files Affected:
```diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 58fb2e91b4f637..899b8c87d0df77 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Operation.h"
@@ -110,6 +111,10 @@ struct LinalgOpInterface
                          ArrayRef<OpOperand *> opOperands) const {
     auto linalgOp = cast<linalg::LinalgOp>(op);
 
+    // Accesses into sparse data structures are not necessarily elementwise.
+    if (sparse_tensor::hasAnySparseOperand(linalgOp))
+      return false;
+
     // All loops must be parallel.
     if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
       return false;
```
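For contrast, the guard added above only fires when an operand carries a sparse tensor encoding; that is what `sparse_tensor::hasAnySparseOperand` checks. A purely dense generic such as the following sketch (again made up, not from the patch) falls through to the existing conditions, such as the all-parallel-loops check, and keeps the optimization.

```mlir
#id1 = affine_map<(d0) -> (d0)>

// No operand has a sparse encoding, every loop is parallel, and the maps
// are identities, so this op still bufferizes to element-wise access.
func.func @dense_elementwise(%a: tensor<8xf32>, %b: tensor<8xf32>) -> tensor<8xf32> {
  %0 = linalg.generic {indexing_maps = [#id1, #id1],
                       iterator_types = ["parallel"]}
      ins(%a : tensor<8xf32>) outs(%b : tensor<8xf32>) {
  ^bb0(%in: f32, %out: f32):
    %s = arith.addf %in, %out : f32
    linalg.yield %s : f32
  } -> tensor<8xf32>
  return %0 : tensor<8xf32>
}
```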
```diff
diff --git a/mlir/test/Dialect/SparseTensor/one_shot_bufferize_tensor_copy_insertion.mlir b/mlir/test/Dialect/SparseTensor/one_shot_bufferize_tensor_copy_insertion.mlir
index 6c2292be161a53..b769acdc7825ce 100644
--- a/mlir/test/Dialect/SparseTensor/one_shot_bufferize_tensor_copy_insertion.mlir
+++ b/mlir/test/Dialect/SparseTensor/one_shot_bufferize_tensor_copy_insertion.mlir
@@ -70,3 +70,39 @@ func.func @update_notinplace(%argb: tensor<10xf32>, %arga: tensor<10xf32, #SV>)
   } -> tensor<10xf32>
   return %0, %argb : tensor<10xf32>, tensor<10xf32>
 }
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64 }>
+
+// linalg.generic with sparse tensors does not necessarily bufferize to
+// element-wise access into the underlying sparse data structures.
+
+// CHECK-LABEL: func @sparse_non_elementwise(
+func.func @sparse_non_elementwise(%arg0: tensor<64x64xf32, #sparse>, %arg1: tensor<64x64xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[alloc0:.*]] = bufferization.alloc_tensor()
+  // CHECK: %[[alloc1:.*]] = bufferization.alloc_tensor()
+  %0 = bufferization.alloc_tensor() : tensor<64x64xf32>
+  // CHECK: %[[generic0:.*]] = linalg.generic {{.*}} outs(%[[alloc1]] : {{.*}})
+  %1 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%0 : tensor<64x64xf32>) {
+  ^bb0(%out: f32):
+    linalg.yield %cst : f32
+  } -> tensor<64x64xf32>
+  // CHECK: linalg.generic {{.*}} outs(%[[generic0]] : {{.*}})
+  %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg2, %arg2 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%1 : tensor<64x64xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %4 = arith.mulf %in, %in_0 : f32
+    %5 = arith.addf %out, %4 : f32
+    linalg.yield %5 : f32
+  } -> tensor<64x64xf32>
+  // CHECK: linalg.generic {{.*}} outs(%[[alloc0]] : {{.*}})
+  %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %2 : tensor<64x64xf32, #sparse>, tensor<64x64xf32>) outs(%0 : tensor<64x64xf32>) attrs = {sorted = true} {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %4 = arith.mulf %in, %in_0 : f32
+    linalg.yield %4 : f32
+  } -> tensor<64x64xf32>
+  return %3 : tensor<64x64xf32>
+}
```
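The new test pins down the expected analysis result: `%0` is the destination of both the first and the last `linalg.generic`, and because the last one has a sparse input it no longer counts as element-wise, so tensor copy insertion must break the conflict by materializing a second `bufferization.alloc_tensor` (the CHECK lines for `%[[alloc0]]` and `%[[alloc1]]`). The pass is driven by the RUN line at the top of the test file, which is outside this hunk; it is presumably along the lines of the following (the exact flags are an assumption, not shown in the diff):

```mlir
// Hypothetical RUN line; the actual one lives at the top of
// one_shot_bufferize_tensor_copy_insertion.mlir and is not part of this hunk.
// RUN: mlir-opt %s -test-tensor-copy-insertion -split-input-file | FileCheck %s
```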