
Commit f3a8af0

[mlir][sparse] best effort finalization of escaping empty sparse tensors (#85482)
This change lifts the restriction that purely allocated empty sparse tensors cannot escape the method. Instead, it makes a best effort to add a finalizing operation before the escape. This assumes that (1) we never build sparse tensors across method boundaries (e.g. allocate in one method, insert in another), and (2) if the empty allocation has other uses in the same method, that op either fails or performs the finalization for us. This is best-effort, but it fixes some very obvious missing cases.
1 parent 43fc921 commit f3a8af0
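
As a rough illustration (a hedged sketch, not taken verbatim from the patch; the function name and the #SV encoding are made up for the example), the new GuardSparseAlloc pattern turns a directly escaping empty allocation

  #SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>

  // Before: the empty sparse allocation escapes without finalization.
  func.func @escape() -> tensor<10xf32, #SV> {
    %0 = bufferization.alloc_tensor() : tensor<10xf32, #SV>
    return %0 : tensor<10xf32, #SV>
  }

into approximately

  // After: a finalizing sparse_tensor.load guards the escape.
  func.func @escape() -> tensor<10xf32, #SV> {
    %0 = bufferization.alloc_tensor() : tensor<10xf32, #SV>
    %1 = sparse_tensor.load %0 hasInserts : tensor<10xf32, #SV>
    return %1 : tensor<10xf32, #SV>
  }

so that the underlying storage is left in a proper state before it crosses the method boundary.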

5 files changed: +177, -44 lines

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 0 additions & 10 deletions
@@ -252,16 +252,6 @@ LogicalResult AllocTensorOp::verify() {
         << getType().getNumDynamicDims() << " dynamic sizes";
   if (getCopy() && getCopy().getType() != getType())
     return emitError("expected that `copy` and return type match");
-
-  // For sparse tensor allocation, we require that none of its
-  // uses escapes the function boundary directly.
-  if (sparse_tensor::getSparseTensorEncoding(getType())) {
-    for (auto &use : getOperation()->getUses())
-      if (isa<func::ReturnOp, func::CallOp, func::CallIndirectOp>(
-              use.getOwner()))
-        return emitError("sparse tensor allocation should not escape function");
-  }
-
   return success();
 }

mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp

Lines changed: 33 additions & 1 deletion
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
@@ -16,6 +17,37 @@ using namespace mlir::sparse_tensor;
 
 namespace {
 
+struct GuardSparseAlloc
+    : public OpRewritePattern<bufferization::AllocTensorOp> {
+  using OpRewritePattern<bufferization::AllocTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(bufferization::AllocTensorOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only rewrite sparse allocations.
+    if (!getSparseTensorEncoding(op.getResult().getType()))
+      return failure();
+
+    // Only rewrite sparse allocations that escape the method
+    // without any chance of a finalizing operation in between.
+    // Here we assume that sparse tensor setup never crosses
+    // method boundaries. The current rewriting only repairs
+    // the most obvious allocate-call/return cases.
+    if (!llvm::all_of(op->getUses(), [](OpOperand &use) {
+          return isa<func::ReturnOp, func::CallOp, func::CallIndirectOp>(
+              use.getOwner());
+        }))
+      return failure();
+
+    // Guard escaping empty sparse tensor allocations with a finalizing
+    // operation that leaves the underlying storage in a proper state
+    // before the tensor escapes across the method boundary.
+    rewriter.setInsertionPointAfter(op);
+    auto load = rewriter.create<LoadOp>(op.getLoc(), op.getResult(), true);
+    rewriter.replaceAllUsesExcept(op, load, load);
+    return success();
+  }
+};
+
 template <typename StageWithSortOp>
 struct StageUnorderedSparseOps : public OpRewritePattern<StageWithSortOp> {
   using OpRewritePattern<StageWithSortOp>::OpRewritePattern;
@@ -37,6 +69,6 @@ struct StageUnorderedSparseOps : public OpRewritePattern<StageWithSortOp> {
 } // namespace
 
 void mlir::populateStageSparseOperationsPatterns(RewritePatternSet &patterns) {
-  patterns.add<StageUnorderedSparseOps<ConvertOp>,
+  patterns.add<GuardSparseAlloc, StageUnorderedSparseOps<ConvertOp>,
                StageUnorderedSparseOps<ConcatenateOp>>(patterns.getContext());
 }

mlir/test/Dialect/Bufferization/invalid.mlir

Lines changed: 0 additions & 23 deletions
@@ -26,29 +26,6 @@ func.func @alloc_tensor_copy_and_dims(%t: tensor<?xf32>, %sz: index) {
 
 // -----
 
-#DCSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>
-
-func.func @sparse_alloc_direct_return() -> tensor<20x40xf32, #DCSR> {
-  // expected-error @+1{{sparse tensor allocation should not escape function}}
-  %0 = bufferization.alloc_tensor() : tensor<20x40xf32, #DCSR>
-  return %0 : tensor<20x40xf32, #DCSR>
-}
-
-// -----
-
-#DCSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>
-
-func.func private @foo(tensor<20x40xf32, #DCSR>) -> ()
-
-func.func @sparse_alloc_call() {
-  // expected-error @+1{{sparse tensor allocation should not escape function}}
-  %0 = bufferization.alloc_tensor() : tensor<20x40xf32, #DCSR>
-  call @foo(%0) : (tensor<20x40xf32, #DCSR>) -> ()
-  return
-}
-
-// -----
-
 // expected-error @+1{{invalid value for 'bufferization.access'}}
 func.func private @invalid_buffer_access_type(tensor<*xf32> {bufferization.access = "foo"})

mlir/test/Dialect/SparseTensor/invalid.mlir

Lines changed: 0 additions & 10 deletions
@@ -868,16 +868,6 @@ func.func @sparse_sort_coo_no_perm(%arg0: index, %arg1: memref<?xindex>) -> (mem
 
 // -----
 
-#CSR = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : dense, d1 : compressed)}>
-
-func.func @sparse_alloc_escapes(%arg0: index) -> tensor<10x?xf64, #CSR> {
-  // expected-error@+1 {{sparse tensor allocation should not escape function}}
-  %0 = bufferization.alloc_tensor(%arg0) : tensor<10x?xf64, #CSR>
-  return %0: tensor<10x?xf64, #CSR>
-}
-
-// -----
-
 #UnorderedCOO = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : compressed(nonunique, nonordered), d1 : singleton(nonordered))}>
 #OrderedCOOPerm = #sparse_tensor.encoding<{map = (d0, d1) -> (d1 : compressed(nonunique), d0 : singleton)}>

Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
//--------------------------------------------------------------------------------------------------
// WHEN CREATING A NEW TEST, PLEASE JUST COPY & PASTE WITHOUT EDITS.
//
// Set-up that's shared across all tests in this directory. In principle, this
// config could be moved to lit.local.cfg. However, there are downstream users that
// do not use these LIT config files. Hence why this is kept inline.
//
// DEFINE: %{sparsifier_opts} = enable-runtime-library=true
// DEFINE: %{sparsifier_opts_sve} = enable-arm-sve=true %{sparsifier_opts}
// DEFINE: %{compile} = mlir-opt %s --sparsifier="%{sparsifier_opts}"
// DEFINE: %{compile_sve} = mlir-opt %s --sparsifier="%{sparsifier_opts_sve}"
// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
// DEFINE: %{run_opts} = -e main -entry-point-result=void
// DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
//
// DEFINE: %{env} =
//--------------------------------------------------------------------------------------------------

// RUN: %{compile} | %{run} | FileCheck %s
//
// Do the same run, but now with direct IR generation.
// REDEFINE: %{sparsifier_opts} = enable-runtime-library=false enable-buffer-initialization=true
// RUN: %{compile} | %{run} | FileCheck %s
//
// Do the same run, but now with direct IR generation and vectorization.
// REDEFINE: %{sparsifier_opts} = enable-runtime-library=false enable-buffer-initialization=true vl=2 reassociate-fp-reductions=true enable-index-optimizations=true
// RUN: %{compile} | %{run} | FileCheck %s
//
// Do the same run, but now with direct IR generation and VLA vectorization.
// RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}


#map = affine_map<(d0) -> (d0)>

#SV = #sparse_tensor.encoding<{
  map = (d0) -> (d0 : compressed)
}>

module {

  // This directly yields an empty sparse vector.
  func.func @empty() -> tensor<10xf32, #SV> {
    %0 = tensor.empty() : tensor<10xf32, #SV>
    return %0 : tensor<10xf32, #SV>
  }

  // This also directly yields an empty sparse vector.
  func.func @empty_alloc() -> tensor<10xf32, #SV> {
    %0 = bufferization.alloc_tensor() : tensor<10xf32, #SV>
    return %0 : tensor<10xf32, #SV>
  }

  // This yields a hidden empty sparse vector (all zeros).
  func.func @zeros() -> tensor<10xf32, #SV> {
    %cst = arith.constant 0.0 : f32
    %0 = bufferization.alloc_tensor() : tensor<10xf32, #SV>
    %1 = linalg.generic {
      indexing_maps = [#map],
      iterator_types = ["parallel"]}
      outs(%0 : tensor<10xf32, #SV>) {
        ^bb0(%out: f32):
          linalg.yield %cst : f32
    } -> tensor<10xf32, #SV>
    return %1 : tensor<10xf32, #SV>
  }

  // This yields a filled sparse vector (all ones).
  func.func @ones() -> tensor<10xf32, #SV> {
    %cst = arith.constant 1.0 : f32
    %0 = bufferization.alloc_tensor() : tensor<10xf32, #SV>
    %1 = linalg.generic {
      indexing_maps = [#map],
      iterator_types = ["parallel"]}
      outs(%0 : tensor<10xf32, #SV>) {
        ^bb0(%out: f32):
          linalg.yield %cst : f32
    } -> tensor<10xf32, #SV>
    return %1 : tensor<10xf32, #SV>
  }

  //
  // Main driver.
  //
  func.func @main() {

    %0 = call @empty() : () -> tensor<10xf32, #SV>
    %1 = call @empty_alloc() : () -> tensor<10xf32, #SV>
    %2 = call @zeros() : () -> tensor<10xf32, #SV>
    %3 = call @ones() : () -> tensor<10xf32, #SV>

    //
    // Verify the output. In particular, make sure that
    // all empty sparse vector data structures are properly
    // finalized with a pair (0,0) for positions.
    //
    // CHECK: ---- Sparse Tensor ----
    // CHECK-NEXT: nse = 0
    // CHECK-NEXT: dim = ( 10 )
    // CHECK-NEXT: lvl = ( 10 )
    // CHECK-NEXT: pos[0] : ( 0, 0,
    // CHECK-NEXT: crd[0] : (
    // CHECK-NEXT: values : (
    // CHECK-NEXT: ----
    //
    // CHECK-NEXT: ---- Sparse Tensor ----
    // CHECK-NEXT: nse = 0
    // CHECK-NEXT: dim = ( 10 )
    // CHECK-NEXT: lvl = ( 10 )
    // CHECK-NEXT: pos[0] : ( 0, 0,
    // CHECK-NEXT: crd[0] : (
    // CHECK-NEXT: values : (
    // CHECK-NEXT: ----
    //
    // CHECK-NEXT: ---- Sparse Tensor ----
    // CHECK-NEXT: nse = 0
    // CHECK-NEXT: dim = ( 10 )
    // CHECK-NEXT: lvl = ( 10 )
    // CHECK-NEXT: pos[0] : ( 0, 0,
    // CHECK-NEXT: crd[0] : (
    // CHECK-NEXT: values : (
    // CHECK-NEXT: ----
    //
    // CHECK-NEXT: ---- Sparse Tensor ----
    // CHECK-NEXT: nse = 10
    // CHECK-NEXT: dim = ( 10 )
    // CHECK-NEXT: lvl = ( 10 )
    // CHECK-NEXT: pos[0] : ( 0, 10,
    // CHECK-NEXT: crd[0] : ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
    // CHECK-NEXT: values : ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    // CHECK-NEXT: ----
    //
    sparse_tensor.print %0 : tensor<10xf32, #SV>
    sparse_tensor.print %1 : tensor<10xf32, #SV>
    sparse_tensor.print %2 : tensor<10xf32, #SV>
    sparse_tensor.print %3 : tensor<10xf32, #SV>

    bufferization.dealloc_tensor %0 : tensor<10xf32, #SV>
    bufferization.dealloc_tensor %1 : tensor<10xf32, #SV>
    bufferization.dealloc_tensor %2 : tensor<10xf32, #SV>
    bufferization.dealloc_tensor %3 : tensor<10xf32, #SV>
    return
  }
}
