
Commit 83b7f01

Author: Peiming Liu
[mlir][sparse] fix crashes when the tensor that defines the loop bound can not be found
Reviewed By: aartbik, K-Wu

Differential Revision: https://reviews.llvm.org/D152877
1 parent e3cc8f3 commit 83b7f01

File tree

4 files changed (+112, -19 lines)


mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp

Lines changed: 11 additions & 16 deletions
@@ -376,8 +376,15 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
     loopIdToOrd[topSort[n]] = n;
   }

-void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc,
-                                     LoopEmitter::OutputUpdater updater) {
+void LoopEmitter::initializeLoopEmit(
+    OpBuilder &builder, Location loc, LoopEmitter::OutputUpdater updater,
+    LoopEmitter::SynTensorBoundSetter synSetter) {
+
+  // For every synthetic tensor, set the high bound by calling the callback.
+  if (synSetter)
+    for (unsigned i = 0, e = highs[getSynTensorId()].size(); i < e; i++)
+      highs[getSynTensorId()][i] = synSetter(builder, loc, i);
+
   // For every manifest tensor:
   // * get the values buffer.
   // * For every level:
@@ -534,27 +541,15 @@ void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc,
   // Prepares for all the tensors used in the current loop sequence.
   std::vector<std::tuple<TensorId, Level, bool>> slicedTids;

-  bool hasSynTensor = false;
-  std::optional<std::pair<TensorId, Level>> loopBoundDefLevel = std::nullopt;
   for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
     if (!dependentLvlMap[tid][lvl].empty()) {
       bool fullyRed = genSliceBegin(builder, loc, tid, lvl);
       slicedTids.emplace_back(tid, lvl, fullyRed);
-    } else {
-      if (isSynTensor(tid)) {
-        hasSynTensor = true;
-      } else {
-        loopBoundDefLevel = std::make_pair(tid, lvl);
-        prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
-      }
+    } else if (!isSynTensor(tid)) {
+      prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
     }
   }

-  if (hasSynTensor && loopBoundDefLevel.has_value()) {
-    // TODO: compute the loopBound for index reduction by d - sum(unres_lvls).
-    highs[getSynTensorId()][getCurrentDepth()] =
-        lvlSizes[loopBoundDefLevel->first][loopBoundDefLevel->second];
-  }
   // Universal Index starts from 0.
   loopSeqStack.emplace_back(C_IDX(0), std::move(slicedTids));
 }

mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h

Lines changed: 8 additions & 1 deletion
@@ -78,6 +78,12 @@ class LoopEmitter {
   /// initializing the loop emitter (e.g., to fill a dense output with zeros).
   using OutputUpdater = function_ref<Value(OpBuilder &builder, Location loc,
                                            Value memref, Value tensor)>;
+
+  /// Optional callback function to set the bound for the synthetic tensor,
+  /// which essentially is the dense loop bound.
+  using SynTensorBoundSetter =
+      function_ref<Value(OpBuilder &builder, Location loc, Level lvl)>;
+
   // Map from [tid, dim] to a list of dependent [tid, dim] for affine expression
   // index on sparse tensors.
   // E.g., for affine index (d0 + d1), it depends on two [tid, dim] that defines
@@ -114,7 +120,8 @@
   /// Starts a loop emitting session by generating all the buffers needed
   /// for iterating over the tensors.
   void initializeLoopEmit(OpBuilder &builder, Location loc,
-                          OutputUpdater updater = nullptr);
+                          OutputUpdater updater = nullptr,
+                          SynTensorBoundSetter synSetter = nullptr);

   /// Generates code to compute an affine expression whose variables are
   /// `LoopId`s (i.e., `a.cast<AffineDimExpr>().getPosition()` is a valid
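The new synSetter parameter defaults to nullptr, so existing callers of initializeLoopEmit are unchanged. As a rough illustration of the callback contract (a hypothetical caller sketched under assumptions, not part of this patch; the actual caller is the lambda added to genBuffers in Sparsification.cpp below), the callback is queried once per level of the synthetic tensor and its result becomes that level's high bound:

// Hypothetical fragment for illustration only (includes and namespaces
// elided); `denseBounds` is an assumed precomputed upper bound per level.
static void startEmission(OpBuilder &builder, Location loc,
                          LoopEmitter &emitter, ArrayRef<Value> denseBounds) {
  emitter.initializeLoopEmit(
      builder, loc,
      /*updater=*/nullptr,
      /*synSetter=*/[&](OpBuilder &b, Location l, Level lvl) -> Value {
        // The returned Value is stored as highs[synTensorId][lvl]; the real
        // caller below returns a null Value for (to-be-removed) filter loops.
        return lvl < denseBounds.size() ? denseBounds[lvl] : Value();
      });
}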

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 36 additions & 2 deletions
@@ -832,6 +832,21 @@ static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
   Location loc = op.getLoc();
   assert(op.getNumOperands() == op.getNumDpsInputs() + 1);

+  SmallVector<Range, 4> loopRange =
+      llvm::cast<linalg::LinalgOp>(op.getOperation())
+          .createLoopRanges(builder, loc);
+
+  assert(loopRange.size() == env.merger().getStartingFilterLoopId());
+  SmallVector<Range, 4> sortedRange;
+  for (unsigned i = 0, e = env.topSortSize(); i < e; i++) {
+    LoopId ldx = env.topSortAt(i);
+    // FIXME: Gets rid of filter loops since we have a better algorithm to deal
+    // with affine index expression.
+    if (ldx < env.merger().getStartingFilterLoopId()) {
+      sortedRange.push_back(loopRange[ldx]);
+    }
+  }
+
   env.emitter().initializeLoopEmit(
       builder, loc,
       /// Generates buffer for the output tensor.
@@ -865,6 +880,16 @@ static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
             ValueRange{init});
         }
         return init;
+      },
+      [&sortedRange, &env](OpBuilder &b, Location loc, Level l) {
+        assert(l < env.topSortSize());
+        // FIXME: Remove filter loop since we have a better algorithm to
+        // deal with affine index expression.
+        if (l >= env.merger().getStartingFilterLoopId())
+          return Value();
+
+        return mlir::getValueOrCreateConstantIndexOp(b, loc,
+                                                     sortedRange[l].size);
       });
 }

@@ -1594,7 +1619,9 @@ static bool translateBitsToTidLvlPairs(
           // iterate based on the level of output tensor. E.g., this
           // could be a synthetic tensor (for invariants and sparse
           // output tensor).
-          if (env.isReduc() && env.merger().getSynTensorID() == tid) {
+          auto itType = env.op().getIteratorTypesArray()[ldx];
+          if (linalg::isReductionIterator(itType) &&
+              env.merger().getSynTensorID() == tid) {
             // Coiterating with an invariant, and this is a reduction loop
             // e.g., out = prod(in[i][j] op invariant);
             // In this case, we can not infer the loop bound from output
@@ -1669,7 +1696,14 @@ static bool translateBitsToTidLvlPairs(
         tidLvls.push_back(env.makeTensorLevel(outTid, *outLvl));
       }

-  assert(numloopCond > 0);
+  if (numloopCond == 0) {
+    // Corner cases where the loop bound is defined by a *unused* operand, in
+    // this case, we just generate a dense "fake" loop by iterating over the
+    // synthetic tensor.
+    tidLvls.push_back(env.makeTensorLevel(env.merger().getSynTensorID(),
+                                          env.emitter().getCurrentDepth()));
+    numloopCond++;
+  }
   // If we just need to one loop conditions and the conditions is not imposed on
   // non-unique level, the loop can be generated by a for loop.
   return numloopCond == 1 && !hasNonUnique;
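With the numloopCond == 0 fallback, a loop whose trip count is defined only by an unused operand is emitted as an ordinary dense loop over the synthetic tensor's bound. Roughly (a sketch only; the exact expected IR appears in the new regression test below):

// Sketch of the emitted k-loop: the bound (8) comes from the corresponding
// dimension of the otherwise unused sparse operand.
scf.for %k = %c0 to %c8 step %c1 {
  // body only reads/writes the dense operands
}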
New test file

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+// RUN: mlir-opt %s -sparsification | FileCheck %s
+
+//
+// A contrived example where the sparse tensor B is only
+// used in the linalg op to determine the number of iterations
+// for the k-loop. This is included to make sure the sparse
+// compiler still generates the correct loop nest for this case.
+//
+
+#SM = #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>
+
+#trait = {
+  indexing_maps = [
+    affine_map<(i,j,k) -> (i,j)>,  // A
+    affine_map<(i,j,k) -> (k,j)>,  // B
+    affine_map<(i,j,k) -> (i,j)>   // S_out
+  ],
+  iterator_types = ["parallel", "parallel", "reduction"],
+  doc = "C(i,j) = SUM_k A(i,j)"
+}
+
+// CHECK-LABEL:   func.func @b_ununsed(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<2x4xf64>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<8x4xf64, #sparse_tensor.encoding<{{.*}}>>,
+// CHECK-SAME:      %[[VAL_2:.*]]: tensor<2x4xf64>) -> tensor<2x4xf64> {
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 8 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 2 : index
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 4 : index
+// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_0]] : memref<2x4xf64>
+// CHECK-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_2]] : memref<2x4xf64>
+// CHECK:           scf.for %[[VAL_10:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] {
+// CHECK:             scf.for %[[VAL_11:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_7]] {
+// CHECK:               scf.for %[[VAL_12:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_7]] {
+// CHECK:                 %[[VAL_13:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_10]], %[[VAL_12]]] : memref<2x4xf64>
+// CHECK:                 %[[VAL_14:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_10]], %[[VAL_12]]] : memref<2x4xf64>
+// CHECK:                 %[[VAL_15:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f64
+// CHECK:                 memref.store %[[VAL_15]], %[[VAL_9]]{{\[}}%[[VAL_10]], %[[VAL_12]]] : memref<2x4xf64>
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:           %[[VAL_16:.*]] = bufferization.to_tensor %[[VAL_9]] : memref<2x4xf64>
+// CHECK:           return %[[VAL_16]] : tensor<2x4xf64>
+// CHECK:         }
+func.func @b_ununsed(%argA: tensor<2x4xf64>,
+                     %argB: tensor<8x4xf64, #SM>,
+                     %argC: tensor<2x4xf64>) -> tensor<2x4xf64> {
+  %result = linalg.generic #trait
+     ins(%argA, %argB: tensor<2x4xf64>, tensor<8x4xf64, #SM>)
+    outs(%argC: tensor<2x4xf64>) {
+      ^bb(%a: f64, %b: f64, %c: f64):
+        %0 = arith.addf %c, %a : f64
+        linalg.yield %0 : f64
+  } -> tensor<2x4xf64>
+  return %result : tensor<2x4xf64>
+}
