Commit a27d886

[mlir][linalg][bufferize] Fix element-wise access optimization for sparse tensors (#87305)
`linalg.generic` ops with sparse tensors do not necessarily bufferize to element-wise access, because insertions into a sparse tensor may change the layout of (or reallocate) the underlying sparse data structures.
1 parent 3ae5c77 commit a27d886
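
The layout-change claim in the commit message can be made concrete with a small standalone C++ sketch (plain C++, not MLIR code and not the sparse_tensor runtime; the CsrMatrix struct and insertNonzero helper are hypothetical names for illustration only): in a CSR-style representation, writing one new nonzero into a row shifts the coordinate/value storage of every later row and rewrites the positions array, and may reallocate the underlying buffers, so a "single element write" is not an element-wise access.

// Illustration only: why an insertion into a CSR-style sparse matrix is not
// an element-wise update. Names and layout are hypothetical, not MLIR's.
#include <cstdint>
#include <cstdio>
#include <vector>

struct CsrMatrix {
  std::vector<int64_t> positions;   // row i occupies [positions[i], positions[i+1])
  std::vector<int64_t> coordinates; // column index of each stored nonzero
  std::vector<double> values;       // value of each stored nonzero
};

// Insert a nonzero at (row, col), keeping each row's coordinates sorted.
void insertNonzero(CsrMatrix &m, int64_t row, int64_t col, double value) {
  int64_t pos = m.positions[row];
  while (pos < m.positions[row + 1] && m.coordinates[pos] < col)
    ++pos;
  // Every stored element at or after `pos` moves; storage may reallocate.
  m.coordinates.insert(m.coordinates.begin() + pos, col);
  m.values.insert(m.values.begin() + pos, value);
  // All later rows start one slot further to the right.
  for (size_t i = row + 1; i < m.positions.size(); ++i)
    ++m.positions[i];
}

int main() {
  // 3x3 matrix with nonzeros at (0,0), (1,2), (2,1).
  CsrMatrix m{{0, 1, 2, 3}, {0, 2, 1}, {1.0, 2.0, 3.0}};
  insertNonzero(m, 1, 0, 4.0); // one write into row 1 ...
  for (double v : m.values)    // ... moved the storage of rows 1 and 2
    std::printf("%g ", v);     // prints: 1 4 2 3
  std::printf("\n");
  return 0;
}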

File tree: 2 files changed, +41 -0 lines


mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 5 additions & 0 deletions

@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Operation.h"
@@ -110,6 +111,10 @@ struct LinalgOpInterface
                           ArrayRef<OpOperand *> opOperands) const {
    auto linalgOp = cast<linalg::LinalgOp>(op);

+    // Accesses into sparse data structures are not necessarily elementwise.
+    if (sparse_tensor::hasAnySparseOperand(linalgOp))
+      return false;
+
    // All loops must be parallel.
    if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
      return false;
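
For context, sparse_tensor::hasAnySparseOperand is an existing SparseTensor dialect helper used as the new guard above. The snippet below is a minimal, hedged sketch of the kind of check it amounts to, assuming getSparseTensorEncoding returns a null encoding attribute for plain dense tensor types; the helper name touchesSparseTensor is hypothetical, and this is an approximation for illustration, not the upstream implementation.

// Hypothetical re-creation (illustration only) of the kind of check the new
// guard performs; the real helper is sparse_tensor::hasAnySparseOperand.
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "llvm/ADT/STLExtras.h"

// Returns true if any operand of `linalgOp` carries a sparse_tensor encoding.
// Dense tensor types have no encoding, so only sparse operands satisfy any_of.
static bool touchesSparseTensor(mlir::linalg::LinalgOp linalgOp) {
  return llvm::any_of(linalgOp->getOperandTypes(), [](mlir::Type t) {
    return static_cast<bool>(mlir::sparse_tensor::getSparseTensorEncoding(t));
  });
}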

mlir/test/Dialect/SparseTensor/one_shot_bufferize_tensor_copy_insertion.mlir

Lines changed: 36 additions & 0 deletions

@@ -70,3 +70,39 @@ func.func @update_notinplace(%argb: tensor<10xf32>, %arga: tensor<10xf32, #SV>)
   } -> tensor<10xf32>
   return %0, %argb : tensor<10xf32>, tensor<10xf32>
 }
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64 }>
+
+// linalg.generic with sparse tensors does not necessarily bufferize to
+// element-wise access into the underlying sparse data structures.
+
+// CHECK-LABEL: func @sparse_non_elementwise(
+func.func @sparse_non_elementwise(%arg0: tensor<64x64xf32, #sparse>, %arg1: tensor<64x64xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[alloc0:.*]] = bufferization.alloc_tensor()
+  // CHECK: %[[alloc1:.*]] = bufferization.alloc_tensor()
+  %0 = bufferization.alloc_tensor() : tensor<64x64xf32>
+  // CHECK: %[[generic0:.*]] = linalg.generic {{.*}} outs(%[[alloc1]] : {{.*}})
+  %1 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%0 : tensor<64x64xf32>) {
+  ^bb0(%out: f32):
+    linalg.yield %cst : f32
+  } -> tensor<64x64xf32>
+  // CHECK: linalg.generic {{.*}} outs(%[[generic0]] : {{.*}})
+  %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg2, %arg2 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%1 : tensor<64x64xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %4 = arith.mulf %in, %in_0 : f32
+    %5 = arith.addf %out, %4 : f32
+    linalg.yield %5 : f32
+  } -> tensor<64x64xf32>
+  // CHECK: linalg.generic {{.*}} outs(%[[alloc0]] : {{.*}})
+  %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %2 : tensor<64x64xf32, #sparse>, tensor<64x64xf32>) outs(%0 : tensor<64x64xf32>) attrs = {sorted = true} {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %4 = arith.mulf %in, %in_0 : f32
+    linalg.yield %4 : f32
+  } -> tensor<64x64xf32>
+  return %3 : tensor<64x64xf32>
+}
