
Commit e7b4c93

Author: Peiming Liu
[mlir][sparse] fix crash when using sparse_tensor::UnaryOp and ReduceOp.
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D152048
1 parent 5a4e344 · commit e7b4c93

File tree: 4 files changed (+114, -46 lines)

mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp

Lines changed: 58 additions & 29 deletions
```diff
@@ -232,7 +232,10 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
   this->hasOutput = hasOutput;
   this->isSparseOut = isSparseOut;
 
-  const unsigned numTensors = ts.size();
+  const unsigned numManifestTensors = ts.size();
+  const unsigned synTensorId = numManifestTensors;
+  const unsigned numTensors = numManifestTensors + 1;
+
   this->tensors.assign(ts.begin(), ts.end());
   this->lvlTypes.assign(numTensors, std::vector<DimLevelType>());
   this->lvlSizes.assign(numTensors, std::vector<Value>());
```
```diff
@@ -265,33 +268,43 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
 
   // Initialize nested types of `TensorId`-indexed fields.
   for (TensorId tid = 0; tid < numTensors; tid++) {
-    const Value t = tensors[tid];
-    // a scalar or 0-dimension tensors
-    if (isZeroRankedTensorOrScalar(t.getType()))
-      continue;
-
-    auto rtp = getRankedTensorType(t);
-    if (auto reshape = t.getDefiningOp<tensor::CollapseShapeOp>();
-        isUniqueCOOType(rtp) && reshape) {
-      // TODO: Supports more kinds of sparse tensors.
-      // FIXME: We should instead lower reshape operations on sparse tensors to
-      // view change.
-      collapseReassoc[tid] = reshape.getReassociation();
-      rtp = reshape.getSrcType();
-      // Overwrites the tensor to the source tensor of reshape operations.
-      tensors[tid] = reshape.getSrc();
-    }
-    const SparseTensorType stt(rtp);
-    const Level lvlRank = stt.getLvlRank();
-    // We always treat sparse output tensor as dense so that we always iterate
-    // it based on lvl size.
-    if (stt.hasEncoding() && !(isOutputTensor(tid) && isSparseOut)) {
-      const auto enc = stt.getEncoding();
-      isSparseSlices[tid] = enc.isSlice();
-      for (auto lvlTp : enc.getLvlTypes())
-        lvlTypes[tid].push_back(lvlTp);
-    } else {
+    Level lvlRank;
+    if (tid == synTensorId) {
+      // Synthetic tensor (conceptually) is an all-dense tensor with rank equal
+      // to the total number of loops (each level can potentially be mapped to
+      // one of the loop being generated).
+      lvlRank = numLoops;
       lvlTypes[tid].assign(lvlRank, DimLevelType::Dense);
+    } else {
+      const Value t = tensors[tid];
+      // a scalar or 0-dimension tensors
+      if (isZeroRankedTensorOrScalar(t.getType()))
+        continue;
+
+      auto rtp = getRankedTensorType(t);
+      if (auto reshape = t.getDefiningOp<tensor::CollapseShapeOp>();
+          isUniqueCOOType(rtp) && reshape) {
+        // TODO: Supports more kinds of sparse tensors.
+        // FIXME: We should instead lower reshape operations on sparse tensors
+        // to view change.
+        collapseReassoc[tid] = reshape.getReassociation();
+        rtp = reshape.getSrcType();
+        // Overwrites the tensor to the source tensor of reshape operations.
+        tensors[tid] = reshape.getSrc();
+      }
+      const SparseTensorType stt(rtp);
+      lvlRank = stt.getLvlRank();
+
+      // We always treat sparse output tensor as dense so that we always iterate
+      // it based on lvl size.
+      if (stt.hasEncoding() && !(isOutputTensor(tid) && isSparseOut)) {
+        const auto enc = stt.getEncoding();
+        isSparseSlices[tid] = enc.isSlice();
+        for (auto lvlTp : enc.getLvlTypes())
+          lvlTypes[tid].push_back(lvlTp);
+      } else {
+        lvlTypes[tid].assign(lvlRank, DimLevelType::Dense);
+      }
     }
 
     // Initialize using empty value.
```
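The hunk above is the heart of the fix: the loop emitter now reserves one extra tensor id past the manifest tensors for a synthetic, conceptually all-dense tensor whose levels stand in for the loops themselves. A minimal standalone sketch of that id layout (illustrative names and sizes, not the actual LoopEmitter fields):

```cpp
// Sketch of the id layout: manifest tensors get ids 0..n-1, the synthetic
// tensor gets id n, and its "levels" are just the loops of the kernel.
#include <cassert>
#include <vector>

enum class DimLevelType { Dense }; // only Dense matters for this sketch

int main() {
  const unsigned numManifestTensors = 2; // e.g. one input and one output
  const unsigned synTensorId = numManifestTensors;
  const unsigned numTensors = numManifestTensors + 1;
  const unsigned numLoops = 3; // loops in the kernel being generated

  std::vector<std::vector<DimLevelType>> lvlTypes(numTensors);
  // The synthetic tensor is all-dense with one level per loop, so any loop
  // index can be bound to one of its levels.
  lvlTypes[synTensorId].assign(numLoops, DimLevelType::Dense);
  assert(lvlTypes[synTensorId].size() == numLoops);
}
```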
```diff
@@ -314,7 +327,7 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
     sliceStack[tid].emplace_back(/*minCrd=*/Value(),
                                  /*offset=*/Value(), /*isNonEmpty*/ Value(),
                                  std::nullopt, 0);
-    if (dimGetter) {
+    if (dimGetter && !isSynTensor(tid)) {
       auto reassoc = collapseReassoc[tid];
       Level dstRank = reassoc ? reassoc.size() : lvlRank;
       for (Level l = 0; l < dstRank; l++) {
```
```diff
@@ -461,15 +474,28 @@ void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc,
   assert(loopSeqStack.size() == loopStack.size());
   // Prepares for all the tensors used in the current loop sequence.
   std::vector<std::tuple<TensorId, Level, bool>> slicedTids;
+
+  bool hasSynTensor = false;
+  std::optional<std::pair<TensorId, Level>> loopBoundDefLevel = std::nullopt;
   for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
     if (!dependentLvlMap[tid][lvl].empty()) {
       bool fullyRed = genSliceBegin(builder, loc, tid, lvl);
       slicedTids.emplace_back(tid, lvl, fullyRed);
     } else {
-      prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
+      if (isSynTensor(tid)) {
+        hasSynTensor = true;
+      } else {
+        loopBoundDefLevel = std::make_pair(tid, lvl);
+        prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
+      }
     }
   }
 
+  if (hasSynTensor && loopBoundDefLevel.has_value()) {
+    // TODO: compute the loopBound for index reduction by d - sum(unres_lvls).
+    highs[getSynTensorId()][getCurrentDepth()] =
+        lvlSizes[loopBoundDefLevel->first][loopBoundDefLevel->second];
+  }
   // Universal Index starts from 0.
   loopSeqStack.emplace_back(C_IDX(0), std::move(slicedTids));
 }
```
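Because the synthetic tensor has no storage of its own, enterNewLoopSeq borrows its loop bound from a manifest tensor level coiterated in the same loop sequence. A hedged sketch of that borrowing, with plain integers standing in for the mlir::Value sizes the real code records:

```cpp
// Sketch: when a loop sequence involves the synthetic tensor, copy the
// level size of a coiterated manifest (tid, lvl) into the synthetic
// tensor's `highs` slot so the loop gets a concrete trip count.
#include <cassert>
#include <optional>
#include <utility>
#include <vector>

int main() {
  const unsigned synTensorId = 1, curDepth = 0;
  std::vector<std::vector<unsigned>> lvlSizes = {{32}};  // tensor 0, level 0
  std::vector<std::vector<unsigned>> highs = {{0}, {0}}; // per tensor, per loop

  bool hasSynTensor = true; // the sequence coiterates the synthetic tensor
  std::optional<std::pair<unsigned, unsigned>> loopBoundDefLevel =
      std::make_pair(0u, 0u); // manifest tensor 0, level 0 defines the bound

  if (hasSynTensor && loopBoundDefLevel.has_value())
    highs[synTensorId][curDepth] =
        lvlSizes[loopBoundDefLevel->first][loopBoundDefLevel->second];
  assert(highs[synTensorId][curDepth] == 32); // bound borrowed successfully
}
```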
```diff
@@ -1137,6 +1163,9 @@ void LoopEmitter::emitExtraLocalsForTensorsAtDenseLvls(
   // output tensor unconditionally, since they may not appear in the lattice,
   // but may be needed for linearized codegen.
   for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
+    if (isSynTensor(tid))
+      continue;
+
     if (isDenseDLT(lvlTypes[tid][lvl])) {
       // Slice-driven dense level should have be handled already.
       if (!dependentLvlMap[tid][lvl].empty())
```

mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h

Lines changed: 12 additions & 4 deletions
```diff
@@ -194,14 +194,18 @@ class LoopEmitter {
   /// Gets the total number of tensors that loopEmitter is operating on.
   unsigned getNumTensors() const { return tensors.size(); }
 
+  /// Gets the TensorId for synthetic tensor.
+  TensorId getSynTensorId() const { return tensors.size(); }
+
   /// Compresses a TensorId and Level into a TensorLevel.
   TensorLevel makeTensorLevel(TensorId t, Level l) const {
-    return l * getNumTensors() + t;
+    // TODO: getNumTensor() should include synthetic tensor.
+    return l * (getNumTensors() + 1) + t;
   }
 
   /// De-compresses a TensorLevel back to a pair of TensorId and Level.
   std::pair<TensorId, Level> unpackTensorLevel(TensorLevel tidLvl) const {
-    unsigned nt = getNumTensors();
+    unsigned nt = getNumTensors() + 1;
     return std::make_pair(tidLvl % nt, tidLvl / nt);
   }
 
```
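The stride change above is easy to get wrong on one side only, so it helps to see the pack/unpack pair round-trip: with n manifest tensors, ids 0..n-1 are manifest and id n is synthetic, so both directions must use a stride of n + 1. A self-contained sketch (an illustrative codec struct, not the actual class):

```cpp
// Sketch of the TensorLevel encoding after this change.
#include <cassert>
#include <utility>

using TensorId = unsigned;
using Level = unsigned;
using TensorLevel = unsigned;

struct TensorLevelCodec {
  unsigned numManifestTensors; // manifest tensors only
  // +1 reserves a slot for the synthetic tensor id.
  unsigned stride() const { return numManifestTensors + 1; }

  TensorLevel pack(TensorId t, Level l) const { return l * stride() + t; }
  std::pair<TensorId, Level> unpack(TensorLevel tl) const {
    return {tl % stride(), tl / stride()};
  }
};

int main() {
  TensorLevelCodec codec{/*numManifestTensors=*/3}; // ids 0,1,2; synthetic id 3
  for (TensorId t = 0; t <= 3; ++t)
    for (Level l = 0; l < 4; ++l) {
      auto [t2, l2] = codec.unpack(codec.pack(t, l));
      assert(t2 == t && l2 == l); // round-trip is lossless for all ids
    }
}
```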
```diff
@@ -319,6 +323,8 @@ class LoopEmitter {
                                  Location loc, Value crd,
                                  TensorId tid, Level lvl);
 
+  bool isSynTensor(TensorId tid) const { return tid == getNumTensors(); }
+
   bool isOutputTensor(TensorId tid) const {
     return hasOutput && tid == getNumTensors() - 1;
   }
```
```diff
@@ -408,9 +414,11 @@ class LoopEmitter {
   /// TODO: why not do this computation when we first store the reassoc,
   /// instead of doing it every time we look it up?
   SmallVector<Level, 2> getCollapseReassociation(TensorId tid, Level dstLvl) {
-    assert(tid < getNumTensors() && "Invalid TensorId");
-    assert(collapseReassoc.size() == getNumTensors());
+    assert(tid < getNumTensors() + 1 && "Invalid TensorId");
+    assert(collapseReassoc.size() == getNumTensors() + 1);
     if (const auto reassoc = collapseReassoc[tid]) {
+      assert(!isSynTensor(tid) && !isOutputTensor(tid) &&
+             "Output/Synthetic tensor should not have reassociation");
       // TODO: store the dstLvlRank in the LoopEmitter so that we can
       // check `dstLvl < dstLvlRank` at the top; and only here need to
       // assert that `reassoc.size() == dstLvlRank`.
```

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 29 additions & 10 deletions
```diff
@@ -1490,8 +1490,15 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                          std::optional<Level> lvl,
                          DimLevelType dlt, bool isIdxReduc) {
   assert(env.merger().loop(b) == idx);
-  if (isDenseDLT(dlt) || isUndefDLT(dlt))
+  if (isDenseDLT(dlt) || isUndefDLT(dlt)) {
+    if (tid == env.merger().getSynTensorID()) {
+      // Needs loop emitter to set up loop bounds for synthetic tensor too if
+      // there is a loop condition imposed on the synthetic tensor.
+      tidLvls.push_back(
+          env.makeTensorLevel(tid, env.emitter().getCurrentDepth()));
+    }
     needsUniv = true;
+  }
   if (isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
       isCompressedWithHiDLT(dlt) || isIdxReduc) {
     // Only when this is a index reduction loop, can the dlt be undefined.
```
```diff
@@ -1575,13 +1582,24 @@ static bool translateBitsToTidLvlPairs(
         // iterate based on the level of output tensor. E.g., this
         // could be a synthetic tensor (for invariants and sparse
         // output tensor).
-        // out[i][j] = invariant; or a broadcast
-        // out[i][j] = in[i] (j is undef for input)
-        tid = outTid;
-        lvl = outLvl;
-        // Skips invalid lvl (e.g., when this is a zero ranked tensor).
-        if (!lvl)
-          return;
+        if (env.isReduc() && env.merger().getSynTensorID() == tid) {
+          // Coiterating with an invariant, and this is a reduction loop
+          // e.g., out = prod(in[i][j] op invariant);
+          // In this case, we can not infer the loop bound from output
+          // (whose level is reduced). Instead we use the synthetic tensor
+          // to infer the bound.
+          // The level of the synthetic tensor is the current loop depth;
+          // the rank of the synthetic tensor equals to number of loops.
+          lvl = env.emitter().getCurrentDepth();
+        } else {
+          // or a broadcast
+          // out[i][j] = in[i] (j is undef for input)
+          tid = outTid;
+          lvl = outLvl;
+          // Skips invalid lvl (e.g., when this is a zero ranked tensor).
+          if (!lvl)
+            return;
+        }
       }
       hasNonUnique = !isUniqueDLT(dlt) || hasNonUnique;
       tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
```
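This branch is what fixes the reported crash for sparse_tensor.unary/reduce kernels: in a reduction loop coiterated only with invariants, the output level is reduced away, so the bound must come from the synthetic tensor at the current loop depth. A condensed sketch of the decision, with a hypothetical Env struct standing in for CodegenEnv:

```cpp
// Sketch: pick the (tid, lvl) whose size bounds the loop when no sparse
// input drives the iteration.
#include <cassert>
#include <optional>
#include <utility>

struct Env { // hypothetical stand-in for CodegenEnv
  bool isReduc;
  unsigned synTensorId, outTensorId, curDepth;
  std::optional<unsigned> outLvl;
};

std::optional<std::pair<unsigned, unsigned>> pickBound(const Env &e,
                                                       unsigned tid) {
  if (e.isReduc && tid == e.synTensorId)
    return std::make_pair(tid, e.curDepth); // synthetic level = loop depth
  if (!e.outLvl)
    return std::nullopt; // zero-ranked output: nothing to iterate
  return std::make_pair(e.outTensorId, *e.outLvl); // invariant or broadcast
}

int main() {
  Env e{/*isReduc=*/true, /*synTensorId=*/2, /*outTensorId=*/1,
        /*curDepth=*/0, /*outLvl=*/std::nullopt};
  auto b = pickBound(e, /*tid=*/2);
  assert(b && b->first == 2 && b->second == 0); // bound from synthetic tensor
}
```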
```diff
@@ -1671,7 +1689,8 @@ static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
   auto allTidLvls =
       llvm::concat<TensorLevel>(tidLvls, llvm::make_first_range(affineTidLvls));
   for (auto [tid, lvl] : env.unpackTensorLevelRange(allTidLvls)) {
-    if (tid != env.merger().getOutTensorID())
+    if (tid != env.merger().getOutTensorID() &&
+        tid != env.merger().getSynTensorID())
       genConstantDenseAddressFromLevel(env, builder, tid, lvl + 1);
   }
 
```
```diff
@@ -1798,7 +1817,7 @@ static void genResult(CodegenEnv &env, RewriterBase &rewriter) {
   } else {
     // To rematerialize an non-annotated tensor, simply load it
     // from the bufferized value.
-    Value val = env.emitter().getValBuffer().back(); // value array
+    Value val = env.emitter().getValBuffer()[env.merger().getOutTensorID()];
     rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val);
   }
 }
```
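The one-line genResult change guards against an off-by-one that the synthetic tensor introduces: per-tensor arrays are now sized with one extra slot, so .back() would presumably land on the synthetic tensor's (empty) slot rather than the output's value buffer. A small sketch of the hazard, with strings modeling present/absent Values:

```cpp
// Sketch: with the synthetic tensor appended, the last slot of a
// numTensors-sized array belongs to the synthetic tensor, not the output.
#include <cassert>
#include <string>
#include <vector>

int main() {
  const unsigned numManifestTensors = 2; // input + output
  const unsigned outTensorId = numManifestTensors - 1;
  const unsigned numTensors = numManifestTensors + 1; // + synthetic

  std::vector<std::string> valBuffer(numTensors); // "" models a null Value
  valBuffer[0] = "in_values";
  valBuffer[outTensorId] = "out_values";

  assert(valBuffer.back().empty());               // synthetic slot: no buffer!
  assert(valBuffer[outTensorId] == "out_values"); // index explicitly instead
}
```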

mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reductions_prod.mlir

Lines changed: 15 additions & 3 deletions
```diff
@@ -140,7 +140,9 @@ module {
       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 4.0
     ]> : tensor<32xf32>
 
-    // Convert constants to annotated tensors.
+    // Convert constants to annotated tensors. Note that this
+    // particular conversion only stores nonzero elements,
+    // so we will have no explicit zeros, only implicit zeros.
     %d0_i32 = sparse_tensor.convert %c_0_i32
       : tensor<32xi32> to tensor<32xi32, #DV>
     %d0_f32 = sparse_tensor.convert %c_0_f32
```
```diff
@@ -158,6 +160,10 @@ module {
     %s1_f32 = sparse_tensor.convert %c_1_f32
       : tensor<32xf32> to tensor<32xf32, #SV>
 
+    // Special case, construct a sparse vector with an explicit zero.
+    %v0 = arith.constant sparse< [ [1] ], [ 0 ] > : tensor<32xi32>
+    %s0 = sparse_tensor.convert %v0: tensor<32xi32> to tensor<32xi32, #SV>
+
     // Call the kernels.
     %0 = call @prod_dreduction_i32(%d0_i32, %ri) : (tensor<32xi32, #DV>, tensor<i32>) -> tensor<i32>
     %1 = call @prod_dreduction_f32(%d0_f32, %rf) : (tensor<32xf32, #DV>, tensor<f32>) -> tensor<f32>
```
```diff
@@ -167,19 +173,23 @@ module {
     %5 = call @prod_dreduction_f32(%d1_f32, %rf) : (tensor<32xf32, #DV>, tensor<f32>) -> tensor<f32>
     %6 = call @prod_sreduction_i32(%s1_i32, %ri) : (tensor<32xi32, #SV>, tensor<i32>) -> tensor<i32>
     %7 = call @prod_sreduction_f32(%s1_f32, %rf) : (tensor<32xf32, #SV>, tensor<f32>) -> tensor<f32>
+    %8 = call @prod_sreduction_i32(%s0, %ri) : (tensor<32xi32, #SV>, tensor<i32>) -> tensor<i32>
 
     // Verify results. Note that the custom reduction gave permission
     // to treat an explicit vs implicit zero differently to compute the
-    // full product reduction. A "standard" product reduction would
-    // have to return 0 for any implicit zero occurrence too.
+    // full product reduction over stored elements. A "standard" product
+    // reduction would have to return 0 for any implicit zero occurrence
+    // too. An explicit zero nullifies the product, though, as requested.
     //
     // CHECK: 0
+    // CHECK: 0
     // CHECK: 3087
     // CHECK: 14
     // CHECK: 3087
     // CHECK: 168
     // CHECK: 3087
     // CHECK: 168
+    // CHECK: 0
     //
     call @dump_i32(%0) : (tensor<i32>) -> ()
     call @dump_f32(%1) : (tensor<f32>) -> ()
```
```diff
@@ -189,6 +199,7 @@ module {
     call @dump_f32(%5) : (tensor<f32>) -> ()
     call @dump_i32(%6) : (tensor<i32>) -> ()
     call @dump_f32(%7) : (tensor<f32>) -> ()
+    call @dump_i32(%8) : (tensor<i32>) -> ()
 
     // Release the resources.
     bufferization.dealloc_tensor %d0_i32 : tensor<32xi32, #DV>
```
```diff
@@ -199,6 +210,7 @@ module {
     bufferization.dealloc_tensor %d1_f32 : tensor<32xf32, #DV>
     bufferization.dealloc_tensor %s1_i32 : tensor<32xi32, #SV>
     bufferization.dealloc_tensor %s1_f32 : tensor<32xf32, #SV>
+    bufferization.dealloc_tensor %s0 : tensor<32xi32, #SV>
 
     return
   }
```
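For intuition about the new test cases: the custom product reduction only visits stored entries, so implicit zeros never enter the product, while the explicitly stored zero of %s0 does enter it and nullifies the result. A plain C++ sketch of those semantics (not the MLIR kernel itself):

```cpp
// Sketch: product over *stored* elements of a sparse vector. Implicit
// zeros are skipped by construction; an explicit stored zero still
// drives the product to 0, matching the new `// CHECK: 0` above.
#include <cstdio>
#include <vector>

struct SparseVec {
  std::vector<unsigned> crd; // coordinates of stored entries
  std::vector<int> val;      // stored values (may include explicit zeros)
};

int prodStored(const SparseVec &v) {
  int p = 1;
  for (int x : v.val) // never visits implicit zeros
    p *= x;
  return p;
}

int main() {
  SparseVec s1{{0, 31}, {3, 7}}; // all other positions are implicit zeros
  SparseVec s0{{1}, {0}};        // one explicit zero stored at position 1
  std::printf("%d\n", prodStored(s1)); // 21: implicit zeros skipped
  std::printf("%d\n", prodStored(s0)); // 0: explicit zero nullifies
}
```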
