
Commit c0d78c4

[mlir][sparse] Implement rewriters to reinterpret maps on alloc_tensor operation (#70993)
1 parent 2667dbf commit c0d78c4

3 files changed: +156 -4 lines changed


mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp (54 additions, 1 deletion)

@@ -6,7 +6,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "CodegenUtils.h"
+
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
@@ -188,6 +191,56 @@ struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
   }
 };
 
+struct TensorAllocDemapper
+    : public OpRewritePattern<bufferization::AllocTensorOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(bufferization::AllocTensorOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!hasNonIdentityOperandsOrResults(op))
+      return failure();
+
+    Location loc = op.getLoc();
+    auto stt = getSparseTensorType(op.getResult());
+
+    SmallVector<Value> maxDimCrds;
+    maxDimCrds.reserve(stt.getDimRank());
+    ValueRange dynSz = op.getDynamicSizes();
+    for (int64_t dimSz : stt.getDimShape()) {
+      if (ShapedType::isDynamic(dimSz)) {
+        Value maxCrd = rewriter.create<arith::SubIOp>(
+            loc, dynSz.front(), constantIndex(rewriter, loc, 1));
+        maxDimCrds.push_back(maxCrd);
+        dynSz = dynSz.drop_front();
+      } else {
+        maxDimCrds.push_back(constantIndex(rewriter, loc, dimSz - 1));
+      }
+    }
+
+    ValueRange maxLvlCrds = stt.translateCrds(rewriter, loc, maxDimCrds,
+                                              CrdTransDirectionKind::dim2lvl);
+    auto lvlShape = stt.getLvlShape();
+    SmallVector<Value> dynLvlSzs;
+    for (unsigned i = 0, e = lvlShape.size(); i < e; i++) {
+      if (ShapedType::isDynamic(lvlShape[i])) {
+        Value sz = rewriter.create<arith::AddIOp>(
+            loc, maxLvlCrds[i], constantIndex(rewriter, loc, 1));
+        dynLvlSzs.push_back(sz);
+      }
+    }
+
+    assert(dynSz.empty()); // should have consumed all.
+    rewriter.startRootUpdate(op);
+    op->setOperands(dynLvlSzs);
+    op.getResult().setType(stt.getDemappedType());
+    rewriter.finalizeRootUpdate(op);
+    rewriter.setInsertionPointAfter(op);
+
+    Value t = genRemap(rewriter, stt.getEncoding(), op.getResult());
+    rewriter.replaceAllUsesExcept(op.getResult(), t, t.getDefiningOp());
+    return success();
+  }
+};
+
 struct TensorInsertDemapper
     : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
   using DemapInsRewriter::DemapInsRewriter;
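
The net effect of TensorAllocDemapper is easiest to see on a concrete allocation. Below is a minimal before/after sketch in MLIR, assuming the 2x2-block #BSR encoding used by the new integration test in this commit; #BSR_demapped is a hypothetical alias for the identity-mapped level type returned by stt.getDemappedType():

    // Before: the allocation is expressed in dimension space (2x4).
    %0 = bufferization.alloc_tensor() : tensor<2x4xf64, #BSR>

    // After: the allocation happens directly in level space (1x2x2x2), and a
    // remapping op emitted by genRemap re-exposes the dimension-space view
    // to any remaining users of the original result.
    %0 = bufferization.alloc_tensor() : tensor<1x2x2x2xf64, #BSR_demapped>
    %1 = sparse_tensor.reinterpret_map %0 : tensor<1x2x2x2xf64, #BSR_demapped>
           to tensor<2x4xf64, #BSR>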
@@ -309,7 +362,7 @@ void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
   }
   if (scope == ReinterpretMapScope::kAll ||
       scope == ReinterpretMapScope::kExceptGeneric) {
-    patterns.add<TensorInsertDemapper, ForeachOpDemapper>(
+    patterns.add<TensorAllocDemapper, TensorInsertDemapper, ForeachOpDemapper>(
         patterns.getContext());
   }
 }
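
As a usage note, these patterns are registered through populateSparseReinterpretMap and exercised by the existing sparse_reinterpret_map.mlir lit test. A sketch of a standalone invocation, assuming the rewrite is exposed under the sparse-reinterpret-map pass flag (an assumption; the flag spelling is not part of this diff):

    // RUN: mlir-opt %s --sparse-reinterpret-map | FileCheck %s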

mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir (2 additions, 3 deletions)

@@ -57,10 +57,9 @@ func.func @mul(%arg0: tensor<32x32xf32>,
 
 // CHECK-LABEL: func.func @sparse_foreach_reinterpret_map(
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<2x4xf64
-// CHECK: %[[VAL_1:.*]] = bufferization.alloc_tensor() : tensor<2x4xf64
+// CHECK: %[[VAL_1:.*]] = bufferization.alloc_tensor() : tensor<1x2x2x2xf64
 // CHECK: %[[VAL_2:.*]] = sparse_tensor.reinterpret_map %[[VAL_0]] : tensor<2x4xf64
-// CHECK: %[[VAL_3:.*]] = sparse_tensor.reinterpret_map %[[VAL_1]] : tensor<2x4xf64
-// CHECK: %[[VAL_4:.*]] = sparse_tensor.foreach in %[[VAL_2]] init(%[[VAL_3]])
+// CHECK: %[[VAL_4:.*]] = sparse_tensor.foreach in %[[VAL_2]] init(%[[VAL_1]])
 // CHECK: ^bb0(%[[VAL_5:.*]]: index, %[[VAL_6:.*]]: index, %[[VAL_7:.*]]: index, %[[VAL_8:.*]]: index, %[[VAL_9:.*]]: f64, %[[VAL_10:.*]]: tensor<1x2x2x2xf64
 // CHECK: %[[VAL_11:.*]] = sparse_tensor.insert %[[VAL_9]] into %[[VAL_10]]{{\[}}%[[VAL_5]], %[[VAL_6]], %[[VAL_7]], %[[VAL_8]]] : tensor<1x2x2x2xf64
 // CHECK: sparse_tensor.yield %[[VAL_11]] : tensor<1x2x2x2xf64
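
The new 1x2x2x2 allocation type checked above mirrors the max-coordinate translation performed by TensorAllocDemapper (for a fully static shape the level sizes come straight from the encoding, but the arithmetic is the same). A worked example for this 2x4 BSR tensor:

    // dim shape (2, 4)      =>  max dim coordinate (1, 3)
    // dim2lvl: (i, j)       ->  (i floordiv 2, j floordiv 2, i mod 2, j mod 2)
    // max lvl coordinate    =   (0, 1, 1, 1)
    // lvl sizes = max + 1   =   (1, 2, 2, 2)  =>  tensor<1x2x2x2xf64, ...>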
New file (100 additions, 0 deletions):

@@ -0,0 +1,100 @@
+//--------------------------------------------------------------------------------------------------
+// WHEN CREATING A NEW TEST, PLEASE JUST COPY & PASTE WITHOUT EDITS.
+//
+// Set-up that's shared across all tests in this directory. In principle, this
+// config could be moved to lit.local.cfg. However, there are downstream users that
+// do not use these LIT config files. Hence why this is kept inline.
+//
+// DEFINE: %{sparse_compiler_opts} = enable-runtime-library=true
+// DEFINE: %{sparse_compiler_opts_sve} = enable-arm-sve=true %{sparse_compiler_opts}
+// DEFINE: %{compile} = mlir-opt %s --sparse-compiler="%{sparse_compiler_opts}"
+// DEFINE: %{compile_sve} = mlir-opt %s --sparse-compiler="%{sparse_compiler_opts_sve}"
+// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
+// DEFINE: %{run_opts} = -e entry -entry-point-result=void
+// DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
+// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
+//
+// DEFINE: %{env} =
+//--------------------------------------------------------------------------------------------------
+
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+// Do the same run, but now with direct IR generation.
+// REDEFINE: %{sparse_compiler_opts} = enable-runtime-library=false enable-buffer-initialization=true
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+// Do the same run, but now with direct IR generation and vectorization.
+// REDEFINE: %{sparse_compiler_opts} = enable-runtime-library=false enable-buffer-initialization=true vl=2 reassociate-fp-reductions=true enable-index-optimizations=true
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+// Do the same run, but now with direct IR generation and VLA vectorization.
+// RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}
+
+#CSR = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed)
+}>
+
+#CSC = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d1 : dense, d0 : compressed)
+}>
+
+#BSR = #sparse_tensor.encoding<{
+  map = ( i, j ) ->
+    ( i floordiv 2 : dense,
+      j floordiv 2 : compressed,
+      i mod 2      : dense,
+      j mod 2      : dense
+    )
+}>
+
+
+//
+// Integration test that tests conversions between sparse tensors.
+//
+module {
+  //
+  // Output utilities.
+  //
+  func.func @dumpf64(%arg0: memref<?xf64>) {
+    %c0 = arith.constant 0 : index
+    %d0 = arith.constant -1.0 : f64
+    %0 = vector.transfer_read %arg0[%c0], %d0: memref<?xf64>, vector<8xf64>
+    vector.print %0 : vector<8xf64>
+    return
+  }
+
+  //
+  // Main driver.
+  //
+  func.func @entry() {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+
+    //
+    // Initialize a 2-dim dense tensor.
+    //
+    %t = arith.constant dense<[
+        [ 1.0, 2.0, 3.0, 4.0 ],
+        [ 5.0, 6.0, 7.0, 8.0 ]
+    ]> : tensor<2x4xf64>
+
+
+    %1 = sparse_tensor.convert %t : tensor<2x4xf64> to tensor<2x4xf64, #CSR>
+    %2 = sparse_tensor.convert %1 : tensor<2x4xf64, #CSR> to tensor<2x4xf64, #BSR>
+    %3 = sparse_tensor.convert %2 : tensor<2x4xf64, #BSR> to tensor<2x4xf64, #CSC>
+
+    %v1 = sparse_tensor.values %1 : tensor<2x4xf64, #CSR> to memref<?xf64>
+    %v2 = sparse_tensor.values %2 : tensor<2x4xf64, #BSR> to memref<?xf64>
+    %v3 = sparse_tensor.values %3 : tensor<2x4xf64, #CSC> to memref<?xf64>
+
+    // CHECK: ( 1, 2, 3, 4, 5, 6, 7, 8 )
+    // CHECK-NEXT: ( 1, 2, 5, 6, 3, 4, 7, 8 )
+    // CHECK-NEXT: ( 1, 5, 2, 6, 3, 7, 4, 8 )
+    call @dumpf64(%v1) : (memref<?xf64>) -> ()
+    call @dumpf64(%v2) : (memref<?xf64>) -> ()
+    call @dumpf64(%v3) : (memref<?xf64>) -> ()
+
+    return
+  }
+}
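
The expected value streams in the CHECK lines above can be verified by hand from the three storage layouts of the same 2x4 matrix [[1, 2, 3, 4], [5, 6, 7, 8]]:

    // CSR stores the rows contiguously:
    //   row 0, row 1                   =>  ( 1, 2, 3, 4, 5, 6, 7, 8 )
    // BSR stores each 2x2 block contiguously, blocks in row-major order:
    //   block (0,0) = [1 2; 5 6]
    //   block (0,1) = [3 4; 7 8]       =>  ( 1, 2, 5, 6, 3, 4, 7, 8 )
    // CSC stores the columns contiguously:
    //   col 0, col 1, col 2, col 3     =>  ( 1, 5, 2, 6, 3, 7, 4, 8 )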
