Skip to content

Commit 9a3d60e

Browse files
committed
[mlir][bufferization][sparse] put restriction on sparse tensor allocation
Putting some direct use restrictions on tensor allocations in the sparse case enables the use of simplifying assumptions in the bufferization analysis. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D128463
1 parent 5a08280 commit 9a3d60e

File tree

6 files changed

+49
-5
lines changed

6 files changed

+49
-5
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,17 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
4949
Both dense and sparse tensor types are supported. The result of a
5050
`bufferization.alloc_tensor` is a tensor value that can be used like any
5151
other tensor value. In practice, it is often used as the "out" operand of
52-
another op. E.g.:
52+
another op. Sparse tensor allocations should always be used in a local
53+
construction operation and never escape the function boundary directly.
54+
55+
Example:
5356

5457
```mlir
5558
%c = bufferization.alloc_tensor [%d1, %d2] : tensor<?x?xf32, #SparseMatrix>
5659
%0 = linalg.matmul
5760
ins(%a, %b: tensor<?x?xf32, #SparseMatrix>, tensor<?x?xf32, #SparseMatrix>)
5861
outs(%c: tensor<?x?xf32, #SparseMatrix>) -> tensor<?x?xf32, #SparseMatrix>
62+
return %0 : tensor<?x?xf32, #SparseMatrix>
5963
```
6064
}];
6165

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1010
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1111
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12+
#include "mlir/Dialect/Func/IR/FuncOps.h"
1213
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1314
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
15+
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
1416
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1517
#include "mlir/IR/Matchers.h"
1618

@@ -250,6 +252,16 @@ LogicalResult AllocTensorOp::verify() {
250252
<< getType().getNumDynamicDims() << " dynamic sizes";
251253
if (getCopy() && getCopy().getType() != getType())
252254
return emitError("expected that `copy` and return type match");
255+
256+
// For sparse tensor allocation, we require that none of its
257+
// uses escapes the function boundary directly.
258+
if (sparse_tensor::getSparseTensorEncoding(getType())) {
259+
for (auto &use : getOperation()->getUses())
260+
if (isa<func::ReturnOp, func::CallOp, func::CallIndirectOp>(
261+
use.getOwner()))
262+
return emitError("sparse tensor allocation should not escape function");
263+
}
264+
253265
return success();
254266
}
255267

mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect
1616
MLIRDialect
1717
MLIRFuncDialect
1818
MLIRIR
19+
MLIRSparseTensorDialect
1920
MLIRTensorDialect
2021
MLIRMemRefDialect
2122
)

mlir/test/Dialect/Bufferization/invalid.mlir

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,4 +54,28 @@ func.func @escape_attr_non_bufferizable(%m0: memref<?xf32>) {
5454
// expected-error @+1{{'bufferization.escape' only valid on bufferizable ops}}
5555
%0 = memref.cast %m0 {bufferization.escape = [true]} : memref<?xf32> to memref<10xf32>
5656
return
57-
}
57+
}
58+
59+
// -----
60+
61+
#DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
62+
63+
func.func @sparse_alloc_direct_return() -> tensor<20x40xf32, #DCSR> {
64+
// expected-error @+1{{sparse tensor allocation should not escape function}}
65+
%0 = bufferization.alloc_tensor() : tensor<20x40xf32, #DCSR>
66+
return %0 : tensor<20x40xf32, #DCSR>
67+
}
68+
69+
// -----
70+
71+
#DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
72+
73+
func.func private @foo(tensor<20x40xf32, #DCSR>) -> ()
74+
75+
func.func @sparse_alloc_call() {
76+
// expected-error @+1{{sparse tensor allocation should not escape function}}
77+
%0 = bufferization.alloc_tensor() : tensor<20x40xf32, #DCSR>
78+
call @foo(%0) : (tensor<20x40xf32, #DCSR>) -> ()
79+
return
80+
}
81+

mlir/test/Dialect/SparseTensor/conversion.mlir

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,8 @@ func.func @sparse_new3d(%arg0: !llvm.ptr<i8>) -> tensor<?x?x?xf32, #SparseTensor
136136
// CHECK: return %[[T]] : !llvm.ptr<i8>
137137
func.func @sparse_init(%arg0: index, %arg1: index) -> tensor<?x?xf64, #SparseMatrix> {
138138
%0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64, #SparseMatrix>
139-
return %0 : tensor<?x?xf64, #SparseMatrix>
139+
%1 = sparse_tensor.load %0 : tensor<?x?xf64, #SparseMatrix>
140+
return %1 : tensor<?x?xf64, #SparseMatrix>
140141
}
141142

142143
// CHECK-LABEL: func @sparse_release(
@@ -580,6 +581,7 @@ func.func @sparse_out2(%arg0: tensor<?x?x?xf32, #SparseTensor>, %arg1: !llvm.ptr
580581
func.func @sparse_and_dense_init(%arg0: index, %arg1: index)
581582
-> (tensor<?x?xf64, #SparseMatrix>, tensor<?x?xf64>) {
582583
%0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64, #SparseMatrix>
583-
%1 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64>
584-
return %0, %1 : tensor<?x?xf64, #SparseMatrix>, tensor<?x?xf64>
584+
%1 = sparse_tensor.load %0 : tensor<?x?xf64, #SparseMatrix>
585+
%2 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64>
586+
return %1, %2 : tensor<?x?xf64, #SparseMatrix>, tensor<?x?xf64>
585587
}

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8958,6 +8958,7 @@ cc_library(
89588958
":IR",
89598959
":InferTypeOpInterface",
89608960
":MemRefDialect",
8961+
":SparseTensorDialect",
89618962
":Support",
89628963
":TensorDialect",
89638964
"//llvm:Support",

0 commit comments

Comments
 (0)