
Commit 015bc34

Author: Peiming Liu

[mlir][sparse] fix crash when sparsifying broadcast operations.

Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D136866

Parent: 8fa32a7

4 files changed, 70 insertions(+), 14 deletions(-)

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp

Lines changed: 8 additions & 7 deletions
@@ -97,10 +97,11 @@ static Value genIndexAndValueForDense(OpBuilder &builder, Location loc,
 SparseTensorLoopEmitter::SparseTensorLoopEmitter(ValueRange tensors,
                                                  bool hasOutput,
                                                  bool isSparseOut)
-    : hasOutput(hasOutput), tensors(tensors.begin(), tensors.end()),
-      dimTypes(tensors.size()), pidxs(tensors.size()), coord(tensors.size()),
-      highs(tensors.size()), ptrBuffer(tensors.size()),
-      idxBuffer(tensors.size()), valBuffer(tensors.size()), loopStack() {
+    : hasOutput(hasOutput), isSparseOut(isSparseOut),
+      tensors(tensors.begin(), tensors.end()), dimTypes(tensors.size()),
+      pidxs(tensors.size()), coord(tensors.size()), highs(tensors.size()),
+      ptrBuffer(tensors.size()), idxBuffer(tensors.size()),
+      valBuffer(tensors.size()), loopStack() {
   for (size_t tid = 0, e = tensors.size(); tid < e; tid++) {
     auto t = tensors[tid];
     // a scalar or 0-dimension tensors
@@ -246,7 +247,7 @@ Operation *SparseTensorLoopEmitter::enterLoopOverTensorAtDim(
   coord[tid][dim] = iv;
   // generate pidx for dense dim (pidx = i * sz + j)
   auto enc = getSparseTensorEncoding(tensors[tid].getType());
-  if (enc)
+  if (enc && !isSparseOutput(tid))
     pidxs[tid][dim] = genAddress(builder, loc, tid, dim, iv);
 }
 
@@ -353,7 +354,7 @@ Operation *SparseTensorLoopEmitter::enterCoIterationOverTensorsAtDims(
     pidxs[tid][dim] = min;
     // generate pidx for dense dim (pidx = i * sz + j)
     auto enc = getSparseTensorEncoding(tensors[tid].getType());
-    if (enc)
+    if (enc && !isSparseOutput(tid))
       pidxs[tid][dim] = genAddress(builder, loc, tid, dim, min);
   }
   // NOTE: we can also prepares for next dim here in advance
@@ -419,7 +420,7 @@ void SparseTensorLoopEmitter::emitExtraLocalsForTensorsAtDenseDims(
   for (auto [tid, dim] : llvm::zip(tids, dims)) {
     assert(isDenseDLT(dimTypes[tid][dim]));
     auto enc = getSparseTensorEncoding(tensors[tid].getType());
-    if (enc) {
+    if (enc && !isSparseOutput(tid)) {
       bool validPidx = dim == 0 || pidxs[tid][dim - 1];
       if (!validPidx) {
         // We might not find the pidx for the sparse output tensor as it is
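For context, genAddress materializes the linearized position described in the comment above (pidx = i * sz + j): the position from the enclosing dimension is scaled by the dense dimension size and offset by the current loop index. A rough sketch of the corresponding IR, with illustrative value names not taken from the patch:

// %parent: position carried over from the enclosing dimension
// %sz:     size of the current dense dimension
// %iv:     loop induction variable for this dimension
%scaled = arith.muli %parent, %sz : index
%pidx = arith.addi %scaled, %iv : index

With this fix, such an address is no longer computed for a sparse output tensor; as the new test below shows, values reach the sparse output through coordinate-based sparse_tensor.insert rather than a precomputed dense position.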

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h

Lines changed: 3 additions & 0 deletions
@@ -434,6 +434,8 @@ class SparseTensorLoopEmitter {
     return hasOutput && tid == tensors.size() - 1;
   }
 
+  bool isSparseOutput(size_t tid) { return isOutputTensor(tid) && isSparseOut; }
+
   /// Setups [lo, hi] for iterating tensor[dim], it assumes that tensor[0
   /// ...dims-1] has already been setup.
   void prepareLoopOverTensorAtDim(OpBuilder &builder, Location loc, size_t tid,
@@ -462,6 +464,7 @@ class SparseTensorLoopEmitter {
   // Whether the loop emitter needs to treat the last tensor as the output
   // tensor.
   bool hasOutput;
+  bool isSparseOut;
   /// Input and (optional) output tensors.
   std::vector<Value> tensors;
   /// The dim type array for each tensor.

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 7 additions & 7 deletions
@@ -1130,13 +1130,13 @@ static void translateBitsToTidDimPairs(Merger &merger, CodeGen &codegen,
   assert(all.test(b));
   assert(merger.index(b) == idx);
   if (isUndefDLT(merger.getDimLevelType(b))) {
-    // This could be a synthetic tensor (for invariants and sparse output
-    // tensor).
-    // In both cases, we mean to generate loops over output tensor.
-    // e.g.,
-    // out[i][j] = invariant;
-    if (merger.getSynTensorID() == tid)
-      tid = merger.getOutTensorID();
+    // An undefined dlt in the lattices, we probably mean to iterate based
+    // on the dim of output tensor.
+    // E.g., this could be a synthetic tensor (for invariants and sparse
+    // output tensor).
+    // out[i][j] = invariant; or a broadcast
+    // out[i][j] = in[i] (j is undef for input)
+    tid = merger.getOutTensorID();
   }
   auto dim = codegen.loopIdxToDim[tid][idx];
   if (dim != INVALID_ID) {
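To make the broadcast case concrete, here is a minimal sketch of the out[i][j] = in[i] pattern the new comment refers to. The input's indexing map does not mention d1, so d1 has an undefined dim level type for the input, and the loop over d1 must be driven by the output tensor. Encodings and names below are illustrative only; the committed test exercises a 3-D variant of the same pattern.

#SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>
#CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>

#trait_bcast = {
  indexing_maps = [
    affine_map<(d0, d1) -> (d0)>,     // input: d1 is undefined here
    affine_map<(d0, d1) -> (d0, d1)>  // output
  ],
  iterator_types = ["parallel", "parallel"]
}

func.func @broadcast_1d_to_2d(%arg0: tensor<4xi32, #SparseVector>) -> tensor<4x3xi32, #CSR> {
  %0 = bufferization.alloc_tensor() : tensor<4x3xi32, #CSR>
  %1 = linalg.generic #trait_bcast
      ins(%arg0 : tensor<4xi32, #SparseVector>)
      outs(%0 : tensor<4x3xi32, #CSR>) {
  ^bb0(%in: i32, %out: i32):
    linalg.yield %in : i32
  } -> tensor<4x3xi32, #CSR>
  return %1 : tensor<4x3xi32, #CSR>
}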
New test file

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+// RUN: mlir-opt %s --sparsification --canonicalize --cse | FileCheck %s
+
+#DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
+#SparseTensor = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ] }>
+
+#trait = {
+  indexing_maps = [
+    affine_map<(d0, d1, d2) -> (d0, d2)>,
+    affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+  ],
+  iterator_types = ["parallel", "parallel", "parallel"]
+}
+
+// CHECK-LABEL: @main(
+// CHECK-SAME: %[[TMP_arg0:.*]]: tensor<4x5xi32,
+// CHECK-DAG: %[[TMP_c3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index
+// CHECK: %[[TMP_0:.*]] = bufferization.alloc_tensor()
+// CHECK: %[[TMP_1:.*]] = sparse_tensor.pointers %[[TMP_arg0]] {dimension = 0 : index}
+// CHECK: %[[TMP_2:.*]] = sparse_tensor.indices %[[TMP_arg0]] {dimension = 0 : index}
+// CHECK: %[[TMP_3:.*]] = sparse_tensor.pointers %[[TMP_arg0]] {dimension = 1 : index}
+// CHECK: %[[TMP_4:.*]] = sparse_tensor.indices %[[TMP_arg0]] {dimension = 1 : index}
+// CHECK: %[[TMP_5:.*]] = sparse_tensor.values %[[TMP_arg0]]
+// CHECK: %[[TMP_6:.*]] = memref.load %[[TMP_1]][%[[TMP_c0]]] : memref<?xindex>
+// CHECK: %[[TMP_7:.*]] = memref.load %[[TMP_1]][%[[TMP_c1]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg1:.*]] = %[[TMP_6]] to %[[TMP_7]] step %[[TMP_c1]] {
+// CHECK: %[[TMP_9:.*]] = memref.load %[[TMP_2]][%[[TMP_arg1]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg2:.*]] = %[[TMP_c0]] to %[[TMP_c3]] step %[[TMP_c1]] {
+// CHECK: %[[TMP_10:.*]] = memref.load %[[TMP_3]][%[[TMP_arg1]]] : memref<?xindex>
+// CHECK: %[[TMP_11:.*]] = arith.addi %[[TMP_arg1]], %[[TMP_c1]] : index
+// CHECK: %[[TMP_12:.*]] = memref.load %[[TMP_3]][%[[TMP_11]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_10]] to %[[TMP_12]] step %[[TMP_c1]] {
+// CHECK: %[[TMP_13:.*]] = memref.load %[[TMP_4]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK: %[[TMP_14:.*]] = memref.load %[[TMP_5]][%[[TMP_arg3]]] : memref<?xi32>
+// CHECK: %[[TMP_15:.*]] = sparse_tensor.insert %[[TMP_14]] into %[[TMP_0]][%[[TMP_9]], %[[TMP_arg2]], %[[TMP_13]]]
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: %[[TMP_8:.*]] = sparse_tensor.load %[[TMP_0]] hasInserts
+// CHECK: return %[[TMP_8]]
+module @func_sparse {
+  func.func public @main(%arg0: tensor<4x5xi32, #DCSR>) -> tensor<4x3x5xi32, #SparseTensor> {
+    %0 = bufferization.alloc_tensor() : tensor<4x3x5xi32, #SparseTensor>
+    %1 = linalg.generic #trait
+      ins(%arg0 : tensor<4x5xi32, #DCSR>) outs(%0 : tensor<4x3x5xi32, #SparseTensor>) {
+    ^bb0(%in: i32, %out: i32):
+      linalg.yield %in : i32
+    } -> tensor<4x3x5xi32, #SparseTensor>
+    return %1 : tensor<4x3x5xi32, #SparseTensor>
+  }
+}
