Skip to content

Commit 6db397a

Browse files
author
Peiming Liu
committed
[mlir][sparse] support dynamic sparse tensor slices.
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D141532
1 parent 8a712bf commit 6db397a

File tree

14 files changed

+450
-112
lines changed

14 files changed

+450
-112
lines changed

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -570,7 +570,7 @@ Level mlir::sparse_tensor::toStoredDim(RankedTensorType type, Dimension d) {
570570
/// We normalized sparse tensor encoding attribute by always using
571571
/// ordered/unique DLT such that "compressed-nu-no" and "compressed-nu" (as well
572572
/// as other variants) lead to the same storage specifier type, and stripping
573-
/// irrelevant fields that does not alter the sparse tensor memory layout.
573+
/// irrelevant fields that do not alter the sparse tensor memory layout.
574574
static SparseTensorEncodingAttr
575575
getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
576576
SmallVector<DimLevelType> dlts;
@@ -582,13 +582,10 @@ getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
582582
AffineMap(), // dimOrdering (irrelavant to storage speicifer)
583583
AffineMap(), // highLvlOrdering (irrelavant to storage specifer)
584584
// Always use `index` for memSize and lvlSize instead of reusing
585-
// `getPosWidth`/`getCrdWidth`.
586-
// It allows us to reuse the same SSA value for different bitwidth,
587-
// It also avoids casting between index/integer (returned by DimOp)
588-
0, 0,
589-
// FIXME: we should keep the slice information, for now it is okay as only
590-
// constant can be used for slice
591-
ArrayRef<SparseTensorDimSliceAttr>{} /*enc.getDimSlices()*/);
585+
// `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA
586+
// value for different bitwidth, it also avoids casting between index and
587+
// integer (returned by DimOp)
588+
0, 0, enc.getDimSlices());
592589
}
593590

594591
StorageSpecifierType
@@ -620,11 +617,10 @@ static LogicalResult verifySparsifierGetterSetter(
620617
const auto enc = md.getType().getEncoding();
621618
const Level lvlRank = enc.getLvlRank();
622619

623-
// TODO:
624-
// if (mdKind == StorageSpecifierKind::DimOffset ||
625-
// mdKind == StorageSpecifierKind::DimStride)
626-
// if (!enc.isSlice())
627-
// return op->emitError("requested slice data on non-slice tensor");
620+
if (mdKind == StorageSpecifierKind::DimOffset ||
621+
mdKind == StorageSpecifierKind::DimStride)
622+
if (!enc.isSlice())
623+
return op->emitError("requested slice data on non-slice tensor");
628624

629625
if (mdKind != StorageSpecifierKind::ValMemSize) {
630626
if (!lvl)

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -694,3 +694,23 @@ Value sparse_tensor::genValMemSize(OpBuilder &builder, Location loc,
694694
Value tensor) {
695695
return getDescriptorFromTensorTuple(tensor).getValMemSize(builder, loc);
696696
}
697+
698+
Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc,
699+
Value tensor, Dimension dim) {
700+
auto enc = getSparseTensorEncoding(tensor.getType());
701+
assert(enc && enc.isSlice());
702+
std::optional<unsigned> offset = enc.getStaticDimSliceOffset(dim);
703+
if (offset.has_value())
704+
return constantIndex(builder, loc, *offset);
705+
return builder.create<ToSliceOffsetOp>(loc, tensor, APInt(64, dim));
706+
}
707+
708+
Value sparse_tensor::createOrFoldSliceStrideOp(OpBuilder &builder, Location loc,
709+
Value tensor, Dimension dim) {
710+
auto enc = getSparseTensorEncoding(tensor.getType());
711+
assert(enc && enc.isSlice());
712+
std::optional<unsigned> stride = enc.getStaticDimSliceStride(dim);
713+
if (stride.has_value())
714+
return constantIndex(builder, loc, *stride);
715+
return builder.create<ToSliceStrideOp>(loc, tensor, APInt(64, dim));
716+
}

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,15 @@ Value genToValues(OpBuilder &builder, Location loc, Value tensor);
364364
/// Generates code to retrieve the values size for the sparse tensor.
365365
Value genValMemSize(OpBuilder &builder, Location loc, Value tensor);
366366

367+
/// Generates code to retrieve the slice offset for the sparse tensor slice,
368+
/// returning a constant if the offset is statically known.
369+
Value createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc, Value tensor,
370+
Dimension dim);
371+
372+
/// Generates code to retrieve the slice stride for the sparse tensor slice,
373+
/// returning a constant if the stride is statically known.
374+
Value createOrFoldSliceStrideOp(OpBuilder &builder, Location loc, Value tensor,
375+
Dimension dim);
367376
} // namespace sparse_tensor
368377
} // namespace mlir
369378

mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp

Lines changed: 79 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -43,29 +43,25 @@ static Value genIndexLoad(OpBuilder &builder, Location loc, Value mem,
4343
return load;
4444
}
4545

46-
// TODO: Support dynamic sized slice.
47-
static Value getSliceOffset(OpBuilder &builder, Location loc,
48-
SparseTensorEncodingAttr enc, unsigned lvl) {
49-
return constantIndex(builder, loc, *enc.getStaticLvlSliceOffset(lvl));
46+
static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor,
47+
unsigned lvl) {
48+
auto enc = getSparseTensorEncoding(tensor.getType());
49+
// FIXME: `toOrigDim` is deprecated
50+
return createOrFoldSliceOffsetOp(builder, loc, tensor, toOrigDim(enc, lvl));
5051
}
5152

52-
static Value getSliceSize(OpBuilder &builder, Location loc,
53-
SparseTensorEncodingAttr enc, unsigned lvl) {
54-
return constantIndex(builder, loc, *enc.getStaticLvlSliceSize(lvl));
55-
}
56-
57-
static Value getSliceStride(OpBuilder &builder, Location loc,
58-
SparseTensorEncodingAttr enc, unsigned lvl) {
59-
return constantIndex(builder, loc, *enc.getStaticLvlSliceStride(lvl));
53+
static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
54+
unsigned lvl) {
55+
auto enc = getSparseTensorEncoding(tensor.getType());
56+
// FIXME: `toOrigDim` is deprecated
57+
return createOrFoldSliceStrideOp(builder, loc, tensor, toOrigDim(enc, lvl));
6058
}
6159

6260
// Converts a coordinate relative to the slice to the coordinate relative
6361
// to the underlying tensor.
6462
static Value toSliceCoord(OpBuilder &builder, Location loc, Value v,
65-
SparseTensorEncodingAttr enc, unsigned lvl) {
66-
67-
Value stride = getSliceStride(builder, loc, enc, lvl);
68-
Value offset = getSliceOffset(builder, loc, enc, lvl);
63+
Value offset, Value stride, Value tensor,
64+
unsigned lvl) {
6965
// iv = iv * stride + offset
7066
v = builder.create<arith::MulIOp>(loc, v, stride);
7167
v = builder.create<arith::AddIOp>(loc, v, offset);
@@ -75,40 +71,58 @@ static Value toSliceCoord(OpBuilder &builder, Location loc, Value v,
7571
// Converts a coordinate relative to the underlying tensor to the coordinate
7672
// relative to the slice, returns a extra reminder value
7773
static std::pair<Value, Value> fromSliceCrd(OpBuilder &builder, Location loc,
78-
Value v,
79-
SparseTensorEncodingAttr enc,
74+
Value iv, Value offset,
75+
Value stride, Value tensor,
8076
unsigned lvl) {
81-
Value stride = getSliceStride(builder, loc, enc, lvl);
82-
Value offset = getSliceOffset(builder, loc, enc, lvl);
8377
// iv = (iv - offset) / stride
84-
v = builder.create<arith::SubIOp>(loc, v, offset);
85-
Value rem = builder.create<arith::RemUIOp>(loc, v, stride);
86-
v = builder.create<arith::DivUIOp>(loc, v, stride);
87-
return std::make_pair(v, rem);
78+
iv = builder.create<arith::SubIOp>(loc, iv, offset);
79+
Value rem = builder.create<arith::RemUIOp>(loc, iv, stride);
80+
iv = builder.create<arith::DivUIOp>(loc, iv, stride);
81+
return std::make_pair(iv, rem);
8882
}
8983

90-
static std::pair<Value, Value>
91-
genSliceLegitPredicate(OpBuilder &builder, Location loc, Value crd,
92-
SparseTensorEncodingAttr enc, unsigned lvl) {
93-
std::pair<Value, Value> trans = fromSliceCrd(builder, loc, crd, enc, lvl);
94-
// First, crd >= offset (TODO: seems unsigned >= 0 won't be folded, skip
95-
// the check if the offset is zero).
96-
auto geOffset =
97-
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, crd,
98-
getSliceOffset(builder, loc, enc, lvl));
84+
std::pair<Value, Value>
85+
LoopEmitter::genSliceLegitPredicate(OpBuilder &builder, Location loc, Value crd,
86+
unsigned tid, unsigned lvl) {
87+
assert(isSparseSlices[tid]);
88+
Value slice = tensors[tid];
89+
Value offset = sliceOffsets[tid][lvl];
90+
Value stride = sliceStrides[tid][lvl];
91+
auto enc = getSparseTensorEncoding(slice.getType());
92+
93+
std::pair<Value, Value> transformedCrd =
94+
fromSliceCrd(builder, loc, crd, offset, stride, slice, lvl);
95+
96+
SmallVector<Value, 3> conds; // at most 3 conditions
97+
98+
// First, coord >= offset (skip the check if offset is known to be 0).
99+
if (auto staticOffset = enc.getStaticLvlSliceOffset(lvl);
100+
!(staticOffset.has_value() && *staticOffset == 0)) {
101+
auto geOffset = builder.create<arith::CmpIOp>(
102+
loc, arith::CmpIPredicate::uge, crd, offset);
103+
conds.push_back(geOffset);
104+
}
105+
99106
// Second, coord_in_slice < length
100-
auto ltLength =
101-
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, trans.first,
102-
getSliceSize(builder, loc, enc, lvl));
103-
104-
// Third, rem == 0; confirmed that (a % 1) will be folded to 0
105-
auto fitStride =
106-
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, trans.second,
107-
constantIndex(builder, loc, 0));
108-
109-
auto pred = builder.create<arith::AndIOp>(loc, geOffset, ltLength);
110-
pred = builder.create<arith::AndIOp>(loc, pred, fitStride);
111-
return {trans.first, pred};
107+
auto ltLength = builder.create<arith::CmpIOp>(
108+
loc, arith::CmpIPredicate::ult, transformedCrd.first, lvlSizes[tid][lvl]);
109+
conds.push_back(ltLength);
110+
111+
// Third, rem == 0 (skip the check if stride is known to be 1).
112+
if (auto staticStride = enc.getStaticLvlSliceStride(lvl);
113+
!(staticStride.has_value() && *staticStride == 1)) {
114+
auto fitStride = builder.create<arith::CmpIOp>(
115+
loc, arith::CmpIPredicate::eq, transformedCrd.second,
116+
constantIndex(builder, loc, 0));
117+
conds.push_back(fitStride);
118+
}
119+
120+
// Must meet all conditions to be a valid coordinate in the slice.
121+
auto pred = conds.front();
122+
for (auto cond : ValueRange(conds).drop_front())
123+
pred = builder.create<arith::AndIOp>(loc, pred, cond);
124+
125+
return {transformedCrd.first, pred};
112126
}
113127

114128
//===----------------------------------------------------------------------===//
@@ -119,10 +133,9 @@ Value LoopEmitter::genAddress(OpBuilder &builder, Location loc, size_t tid,
119133
size_t dim, Value iv) {
120134
Value p = dim == 0 ? constantIndex(builder, loc, 0) : pidxs[tid][dim - 1];
121135
Value mul = builder.create<arith::MulIOp>(loc, highs[tid][dim], p);
122-
if (isSparseSlices[tid]) {
123-
auto enc = getSparseTensorEncoding(tensors[tid].getType());
124-
iv = toSliceCoord(builder, loc, iv, enc, dim);
125-
}
136+
if (isSparseSlices[tid])
137+
iv = toSliceCoord(builder, loc, iv, sliceOffsets[tid][dim],
138+
sliceStrides[tid][dim], tensors[tid], dim);
126139
Value add = builder.create<arith::AddIOp>(loc, mul, iv);
127140
return add;
128141
}
@@ -204,6 +217,8 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
204217
this->isSparseOut = isSparseOut;
205218
this->tensors.assign(ts.begin(), ts.end());
206219
this->isSparseSlices.assign(tensors.size(), false);
220+
this->sliceOffsets.assign(tensors.size(), std::vector<Value>());
221+
this->sliceStrides.assign(tensors.size(), std::vector<Value>());
207222
this->dimTypes.assign(tensors.size(), std::vector<DimLevelType>());
208223
this->pidxs.assign(tensors.size(), std::vector<Value>());
209224
this->segHi.assign(tensors.size(), std::vector<Value>());
@@ -246,6 +261,8 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
246261
dimTypes[tid].assign(rank, DimLevelType::Dense);
247262

248263
// Initialize using empty value.
264+
sliceOffsets[tid].assign(rank, Value());
265+
sliceStrides[tid].assign(rank, Value());
249266
pidxs[tid].assign(rank, Value());
250267
segHi[tid].assign(rank, Value());
251268
coord[tid].assign(rank, Value());
@@ -300,11 +317,17 @@ void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc,
300317
assert(isDenseDLT(dlt));
301318
}
302319

303-
// Find upper bound in current dimension.
304320
// FIXME: `toOrigDim` is deprecated
305-
const Dimension d = toOrigDim(enc, l);
306-
lvlSizes[t][l] = highs[t][l] =
307-
mlir::linalg::createOrFoldDimOp(builder, loc, tensor, d);
321+
// Since we do not have HigherOrdering now, we can always rely on the 1:1
322+
// mapping from level to dimension to retrieve the level size.
323+
Value lvlSz = mlir::linalg::createOrFoldDimOp(builder, loc, tensor,
324+
toOrigDim(enc, l));
325+
// Find upper bound in current dimension.
326+
highs[t][l] = lvlSizes[t][l] = lvlSz;
327+
if (isSparseSlices[t]) {
328+
sliceOffsets[t][l] = genSliceOffset(builder, loc, tensors[t], l);
329+
sliceStrides[t][l] = genSliceStride(builder, loc, tensors[t], l);
330+
}
308331
}
309332

310333
// Perform the required bufferization. Dense inputs materialize
@@ -405,7 +428,6 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
405428
isSparseInput = isSparseInput || isSparse;
406429
}
407430

408-
auto enc = getSparseTensorEncoding(tensors[tid].getType());
409431
const auto reassoc = getCollapseReassociation(tid, dim);
410432
// TODO: support dynamic slices.
411433
// Uses the first dimension here to build the loop bound (which is also the
@@ -468,7 +490,7 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
468490
for (Value red : reduc)
469491
types.push_back(red.getType());
470492

471-
auto [trans, pred] = genSliceLegitPredicate(builder, loc, crd, enc, dim);
493+
auto [trans, pred] = genSliceLegitPredicate(builder, loc, crd, tid, dim);
472494
bool hasReduc = !types.empty();
473495
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, pred,
474496
/*else*/ hasReduc);
@@ -660,11 +682,8 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtDims(
660682
isSingletonDLT(dimTypes[tid][dim])) {
661683
coord[tid][dim] = genSparseCrd(builder, loc, tid, dim);
662684
if (isSparseSlices[tid]) {
663-
Value load =
664-
genIndexLoad(builder, loc, crdBuffer[tid][dim], pidxs[tid][dim]);
665-
auto enc = getSparseTensorEncoding(tensors[tid].getType());
666685
auto [trans, pred] =
667-
genSliceLegitPredicate(builder, loc, load, enc, dim);
686+
genSliceLegitPredicate(builder, loc, coord[tid][dim], tid, dim);
668687
slicesPreds.emplace_back(pred, i);
669688
// Updates to the relative coordinate to the slice.
670689
coord[tid][dim] = trans;
@@ -679,7 +698,7 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtDims(
679698
// Generates a list of if statments
680699
// pidx = in_slice ? pidx : pidx + 1
681700
// TODO: instead of always picking pidx + 1, we should set pidx = high to
682-
// break to loop the coordinates is larger than the slice size.
701+
// break the loop if the coordinate is larger than the slice size.
683702
for (auto [pred, idx] : slicesPreds) {
684703
Value nextPidx = builder.create<arith::AddIOp>(
685704
loc, yields[idx], constantIndex(builder, loc, 1));

mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,13 @@ class LoopEmitter {
202202
Value genSparseCrd(OpBuilder &builder, Location loc, size_t tid,
203203
size_t dstLvl);
204204

205+
/// Generates a predicate to determine whether the transformed coordinates are
206+
/// in the given slice.
207+
/// Returns std::pair<Transformed coordinates, Predicate>
208+
std::pair<Value, Value> genSliceLegitPredicate(OpBuilder &builder,
209+
Location loc, Value crd,
210+
unsigned tid, unsigned lvl);
211+
205212
bool isOutputTensor(size_t tid) {
206213
return hasOutput && tid == tensors.size() - 1;
207214
}
@@ -278,6 +285,9 @@ class LoopEmitter {
278285

279286
/// Whether the sparse input is a slice.
280287
std::vector<bool> isSparseSlices;
288+
/// Values related to slices.
289+
std::vector<std::vector<Value>> sliceOffsets;
290+
std::vector<std::vector<Value>> sliceStrides;
281291

282292
/// Loop Stack, stores the information of all the nested loops that are
283293
/// alive.

mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -130,17 +130,18 @@ Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc,
130130
/// Builds IR extracting the pos-th offset from the descriptor.
131131
Value SpecifierStructBuilder::dimOffset(OpBuilder &builder, Location loc,
132132
Dimension dim) const {
133-
return builder.create<LLVM::ExtractValueOp>(
134-
loc, value,
135-
ArrayRef<int64_t>({kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)}));
133+
return extractField(
134+
builder, loc,
135+
ArrayRef<int64_t>{kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)});
136136
}
137137

138138
/// Builds IR inserting the pos-th offset into the descriptor.
139139
void SpecifierStructBuilder::setDimOffset(OpBuilder &builder, Location loc,
140140
Dimension dim, Value size) {
141-
value = builder.create<LLVM::InsertValueOp>(
142-
loc, value, size,
143-
ArrayRef<int64_t>({kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)}));
141+
insertField(
142+
builder, loc,
143+
ArrayRef<int64_t>{kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)},
144+
size);
144145
}
145146

146147
/// Builds IR extracting the `lvl`-th level-size from the descriptor.

0 commit comments

Comments
 (0)