Skip to content

Commit 160d483

Browse files
authored
[mlir][sparse] implement loose-compressed/2:4 on direct IR codegen path (llvm#71461)
Fills in the missing cases for direct IR codegen. Note that non-permutation handling is still TBD.
1 parent 16a395b commit 160d483

File tree

4 files changed

+144
-84
lines changed

4 files changed

+144
-84
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -633,8 +633,8 @@ void sparse_tensor::fillDimShape(OpBuilder &builder, Location loc,
633633
SmallVectorImpl<Value> &out) {
634634
out.clear();
635635
out.reserve(stt.getDimRank());
636-
for (const Size sh : stt.getDimShape()) {
637-
const auto s = ShapedType::isDynamic(sh) ? 0 : sh;
636+
for (const Size sz : stt.getDimShape()) {
637+
const auto s = ShapedType::isDynamic(sz) ? 0 : sz;
638638
out.push_back(constantIndex(builder, loc, s));
639639
}
640640
}

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

Lines changed: 78 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
#include "CodegenUtils.h"
1919
#include "SparseTensorDescriptor.h"
2020

21-
#include "llvm/Support/FormatVariadic.h"
22-
2321
#include "mlir/Dialect/Arith/Utils/Utils.h"
2422
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
2523
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -116,31 +114,36 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
116114
const SparseTensorType stt(desc.getRankedTensorType());
117115
Value linear = constantIndex(builder, loc, 1);
118116
const Level lvlRank = stt.getLvlRank();
119-
for (Level l = startLvl; l < lvlRank; l++) {
120-
const auto dlt = stt.getLvlType(l);
121-
if (isCompressedDLT(dlt)) {
117+
for (Level lvl = startLvl; lvl < lvlRank; lvl++) {
118+
const auto dlt = stt.getLvlType(lvl);
119+
if (isCompressedDLT(dlt) || isLooseCompressedDLT(dlt)) {
122120
// Append linear x positions, initialized to zero. Since each compressed
123121
// dimension initially already has a single zero entry, this maintains
124-
// the desired "linear + 1" length property at all times.
122+
// the desired "linear + 1" length property at all times. For loose
123+
// compression, we multiply linear by two in order to append both the
124+
// lo/hi positions.
125125
Value posZero = constantZero(builder, loc, stt.getPosType());
126-
createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, l,
127-
posZero, linear);
126+
if (isLooseCompressedDLT(dlt)) {
127+
Value two = constantIndex(builder, loc, 2);
128+
linear = builder.create<arith::MulIOp>(loc, linear, two);
129+
}
130+
createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
131+
/*value=*/posZero, /*repeat=*/linear);
128132
return;
129-
}
130-
if (isSingletonDLT(dlt)) {
133+
} else if (isSingletonDLT(dlt) || is2OutOf4DLT(dlt)) {
131134
return; // nothing to do
132135
}
133136
// Keep compounding the size, but nothing needs to be initialized
134137
// at this level. We will eventually reach a compressed level or
135138
// otherwise the values array for the from-here "all-dense" case.
136139
assert(isDenseDLT(dlt));
137-
Value size = desc.getLvlSize(builder, loc, l);
140+
Value size = desc.getLvlSize(builder, loc, lvl);
138141
linear = builder.create<arith::MulIOp>(loc, linear, size);
139142
}
140143
// Reached values array so prepare for an insertion.
141144
Value valZero = constantZero(builder, loc, stt.getElementType());
142145
createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
143-
std::nullopt, valZero, linear);
146+
std::nullopt, /*value=*/valZero, /*repeat=*/linear);
144147
}
145148

146149
/// Creates allocation operation.
@@ -157,12 +160,9 @@ static Value createAllocation(OpBuilder &builder, Location loc,
157160
}
158161

159162
/// Creates allocation for each field in sparse tensor type. Note that
160-
/// for all dynamic memrefs, the memory size is really the capacity of
161-
/// the "vector", while the actual size resides in the sizes array.
162-
///
163-
/// TODO: for efficiency, we will need heuristics to make educated guesses
164-
/// on the required capacities (see heuristic variable).
165-
///
163+
/// for all dynamic memrefs in the sparse tensor stroage layout, the
164+
/// memory size is really the capacity of the "vector", while the actual
165+
/// size resides in the sizes array.
166166
static void createAllocFields(OpBuilder &builder, Location loc,
167167
SparseTensorType stt, ValueRange dynSizes,
168168
bool enableInit, SmallVectorImpl<Value> &fields,
@@ -206,6 +206,8 @@ static void createAllocFields(OpBuilder &builder, Location loc,
206206
constantIndex(builder, loc, 16);
207207
}
208208

209+
// Initializes all fields. An initial storage specifier and allocated
210+
// positions/coordinates/values memrefs (with heuristic capacity).
209211
foreachFieldAndTypeInSparseTensor(
210212
stt,
211213
[&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic,
@@ -218,14 +220,16 @@ static void createAllocFields(OpBuilder &builder, Location loc,
218220
field = SparseTensorSpecifier::getInitValue(builder, loc, stt);
219221
break;
220222
case SparseTensorFieldKind::PosMemRef:
223+
field = createAllocation(builder, loc, cast<MemRefType>(fType),
224+
posHeuristic, enableInit);
225+
break;
221226
case SparseTensorFieldKind::CrdMemRef:
227+
field = createAllocation(builder, loc, cast<MemRefType>(fType),
228+
crdHeuristic, enableInit);
229+
break;
222230
case SparseTensorFieldKind::ValMemRef:
223-
field = createAllocation(
224-
builder, loc, cast<MemRefType>(fType),
225-
(fKind == SparseTensorFieldKind::PosMemRef) ? posHeuristic
226-
: (fKind == SparseTensorFieldKind::CrdMemRef) ? crdHeuristic
227-
: valHeuristic,
228-
enableInit);
231+
field = createAllocation(builder, loc, cast<MemRefType>(fType),
232+
valHeuristic, enableInit);
229233
break;
230234
}
231235
assert(field);
@@ -234,21 +238,19 @@ static void createAllocFields(OpBuilder &builder, Location loc,
234238
return true;
235239
});
236240

241+
// Initialize the storage scheme to an empty tensor. Sets the lvlSizes
242+
// and gives all position fields an initial zero entry, so that it is
243+
// easier to maintain the "linear + 1" length property.
237244
MutSparseTensorDescriptor desc(stt, fields);
238-
239-
// Initialize the storage scheme to an empty tensor. Initialized memSizes
240-
// to all zeros, sets the dimSizes to known values and gives all position
241-
// fields an initial zero entry, so that it is easier to maintain the
242-
// "linear + 1" length property.
243245
Value posZero = constantZero(builder, loc, stt.getPosType());
244-
for (Level lvlRank = stt.getLvlRank(), l = 0; l < lvlRank; l++) {
245-
// Fills dim sizes array.
246+
for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
246247
// FIXME: `toOrigDim` is deprecated.
247-
desc.setLvlSize(builder, loc, l, dimSizes[toOrigDim(stt.getEncoding(), l)]);
248-
// Pushes a leading zero to positions memref.
249-
if (stt.isCompressedLvl(l))
250-
createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, l,
251-
posZero);
248+
desc.setLvlSize(builder, loc, lvl,
249+
dimSizes[toOrigDim(stt.getEncoding(), lvl)]);
250+
const auto dlt = stt.getLvlType(lvl);
251+
if (isCompressedDLT(dlt) || isLooseCompressedDLT(dlt))
252+
createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
253+
/*value=*/posZero);
252254
}
253255
allocSchemeForRank(builder, loc, desc, /*rank=*/0);
254256
}
@@ -347,7 +349,7 @@ static Value genCompressed(OpBuilder &builder, Location loc,
347349
Value mszp1 = builder.create<arith::AddIOp>(loc, msz, one);
348350
genStore(builder, loc, mszp1, positionsAtLvl, pp1);
349351
createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, lvl,
350-
lvlCoords[lvl]);
352+
/*value=*/lvlCoords[lvl]);
351353
// Prepare the next level "as needed".
352354
if ((lvl + 1) < lvlRank)
353355
allocSchemeForRank(builder, loc, desc, lvl + 1);
@@ -371,8 +373,6 @@ static void genEndInsert(OpBuilder &builder, Location loc,
371373
const Level lvlRank = stt.getLvlRank();
372374
for (Level l = 0; l < lvlRank; l++) {
373375
const auto dlt = stt.getLvlType(l);
374-
if (isLooseCompressedDLT(dlt))
375-
llvm_unreachable("TODO: Not yet implemented");
376376
if (isCompressedDLT(dlt)) {
377377
// Compressed dimensions need a position cleanup for all entries
378378
// that were not visited during the insertion pass.
@@ -407,7 +407,8 @@ static void genEndInsert(OpBuilder &builder, Location loc,
407407
builder.setInsertionPointAfter(loop);
408408
}
409409
} else {
410-
assert(isDenseDLT(dlt) || isSingletonDLT(dlt));
410+
assert(isDenseDLT(dlt) || isLooseCompressedDLT(dlt) ||
411+
isSingletonDLT(dlt) || is2OutOf4DLT(dlt));
411412
}
412413
}
413414
}
@@ -483,33 +484,37 @@ class SparseInsertGenerator
483484
Value value = args.back();
484485
Value parentPos = constantZero(builder, loc, builder.getIndexType());
485486
// Generate code for every level.
486-
for (Level l = 0; l < lvlRank; l++) {
487-
const auto dlt = stt.getLvlType(l);
488-
if (isCompressedDLT(dlt)) {
487+
for (Level lvl = 0; lvl < lvlRank; lvl++) {
488+
const auto dlt = stt.getLvlType(lvl);
489+
if (isCompressedDLT(dlt) || isLooseCompressedDLT(dlt)) {
489490
// Create:
490491
// if (!present) {
491-
// coordinates[l].push_back(coords[l])
492-
// <update positions and prepare level l + 1>
492+
// coordinates[lvl].push_back(coords[lvl])
493+
// <update positions and prepare level lvl + 1>
493494
// }
494-
// positions[l] = coordinates.size() - 1
495-
// <insert @ positions[l] at next level l + 1>
495+
// positions[lvl] = coordinates.size() - 1
496+
// <insert @ positions[lvl] at next level lvl + 1>
497+
if (isLooseCompressedDLT(dlt)) {
498+
Value two = constantIndex(builder, loc, 2);
499+
parentPos = builder.create<arith::MulIOp>(loc, parentPos, two);
500+
}
496501
parentPos =
497-
genCompressed(builder, loc, desc, coords, value, parentPos, l);
498-
} else if (isSingletonDLT(dlt)) {
502+
genCompressed(builder, loc, desc, coords, value, parentPos, lvl);
503+
} else if (isSingletonDLT(dlt) || is2OutOf4DLT(dlt)) {
499504
// Create:
500-
// coordinates[l].push_back(coords[l])
501-
// positions[l] = positions[l-1]
502-
// <insert @ positions[l] at next level l + 1>
503-
createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, l,
504-
coords[l]);
505+
// coordinates[lvl].push_back(coords[lvl])
506+
// positions[lvl] = positions[lvl-1]
507+
// <insert @ positions[lvl] at next level lvl + 1>
508+
createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef,
509+
lvl, /*value=*/coords[lvl]);
505510
} else {
506511
assert(isDenseDLT(dlt));
507512
// Construct the new position as:
508-
// positions[l] = size * positions[l-1] + coords[l]
509-
// <insert @ positions[l] at next level l + 1>
510-
Value size = desc.getLvlSize(builder, loc, l);
513+
// positions[lvl] = size * positions[lvl-1] + coords[lvl]
514+
// <insert @ positions[lvl] at next level lvl + 1>
515+
Value size = desc.getLvlSize(builder, loc, lvl);
511516
Value mult = builder.create<arith::MulIOp>(loc, size, parentPos);
512-
parentPos = builder.create<arith::AddIOp>(loc, mult, coords[l]);
517+
parentPos = builder.create<arith::AddIOp>(loc, mult, coords[lvl]);
513518
}
514519
}
515520
// Reached the actual value append/insert.
@@ -526,7 +531,6 @@ class SparseInsertGenerator
526531
// <namePrefix>_<DLT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
527532
constexpr const char kInsertFuncNamePrefix[] = "_insert_";
528533
const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
529-
530534
SmallString<32> nameBuffer;
531535
llvm::raw_svector_ostream nameOstream(nameBuffer);
532536
nameOstream << kInsertFuncNamePrefix;
@@ -543,8 +547,8 @@ class SparseInsertGenerator
543547
// Static dim sizes are used in the generated code while dynamic sizes are
544548
// loaded from the dimSizes buffer. This is the reason for adding the shape
545549
// to the function name.
546-
for (const auto sh : stt.getDimShape())
547-
nameOstream << sh << "_";
550+
for (const auto sz : stt.getDimShape())
551+
nameOstream << sz << "_";
548552
// Permutation information is also used in generating insertion.
549553
if (!stt.isIdentity())
550554
nameOstream << stt.getDimToLvl() << "_";
@@ -607,7 +611,6 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
607611
assert(retOffset < newCall.getNumResults());
608612
auto retType = ret.getType();
609613
if (failed(typeConverter->convertType(retType, sparseFlat)))
610-
// This should never happen.
611614
llvm_unreachable("Failed to convert type in sparse tensor codegen");
612615

613616
// Converted types can not be empty when the type conversion succeed.
@@ -755,9 +758,7 @@ class SparseTensorAllocConverter
755758
const auto resType = getSparseTensorType(op);
756759
if (!resType.hasEncoding())
757760
return failure();
758-
759-
// Construct allocation for each field.
760-
const Location loc = op.getLoc();
761+
Location loc = op.getLoc();
761762
if (op.getCopy()) {
762763
auto desc = getDescriptorFromTensorTuple(adaptor.getCopy());
763764
SmallVector<Value> fields;
@@ -778,18 +779,18 @@ class SparseTensorAllocConverter
778779
return success();
779780
}
780781

781-
const Value sizeHint = op.getSizeHint();
782-
const ValueRange dynSizes = adaptor.getDynamicSizes();
782+
// Construct allocation for each field.
783+
Value sizeHint = op.getSizeHint();
784+
ValueRange dynSizes = adaptor.getDynamicSizes();
783785
const size_t found = dynSizes.size();
784786
const int64_t expected = resType.getNumDynamicDims();
785787
if (found != static_cast<size_t>(expected))
786-
return rewriter.notifyMatchFailure(
787-
op, llvm::formatv(
788-
"Got wrong number of dynamic sizes: Found={0}, Expected={1}",
789-
found, expected));
788+
return rewriter.notifyMatchFailure(op,
789+
"Got wrong number of dynamic sizes");
790790
SmallVector<Value> fields;
791791
createAllocFields(rewriter, loc, resType, dynSizes,
792792
enableBufferInitialization, fields, sizeHint);
793+
793794
// Replace operation with resulting memrefs.
794795
rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
795796
return success();
@@ -817,19 +818,18 @@ class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
817818
return failure();
818819

819820
// Construct allocation for each field.
820-
const Location loc = op.getLoc();
821-
const Value sizeHint; // none
821+
Location loc = op.getLoc();
822+
Value sizeHint; // none
822823
const ValueRange dynSizes = adaptor.getDynamicSizes();
823824
const size_t found = dynSizes.size();
824825
const int64_t expected = resType.getNumDynamicDims();
825826
if (found != static_cast<size_t>(expected))
826-
return rewriter.notifyMatchFailure(
827-
op, llvm::formatv(
828-
"Got wrong number of dynamic sizes: Found={0}, Expected={1}",
829-
found, expected));
827+
return rewriter.notifyMatchFailure(op,
828+
"Got wrong number of dynamic sizes");
830829
SmallVector<Value> fields;
831830
createAllocFields(rewriter, loc, resType, dynSizes,
832831
enableBufferInitialization, fields, sizeHint);
832+
833833
// Replace operation with resulting memrefs.
834834
rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
835835
return success();
@@ -1496,7 +1496,6 @@ struct SparseNewConverter : public OpConversionPattern<NewOp> {
14961496
SmallVector<Value> fields;
14971497
createAllocFields(rewriter, loc, dstTp, dynSizes, /*enableInit=*/false,
14981498
fields, nse);
1499-
MutSparseTensorDescriptor desc(dstTp, fields);
15001499

15011500
// Now construct the dim2lvl and lvl2dim buffers.
15021501
Value dim2lvlBuffer;
@@ -1505,6 +1504,7 @@ struct SparseNewConverter : public OpConversionPattern<NewOp> {
15051504
dim2lvlBuffer, lvl2dimBuffer);
15061505

15071506
// Read the COO tensor data.
1507+
MutSparseTensorDescriptor desc(dstTp, fields);
15081508
Value xs = desc.getAOSMemRef();
15091509
Value ys = desc.getValMemRef();
15101510
const Type boolTp = rewriter.getIntegerType(1);

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -380,14 +380,13 @@ class SparseTensorAllocConverter
380380
LogicalResult
381381
matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
382382
ConversionPatternRewriter &rewriter) const override {
383-
if (op.getCopy())
384-
return rewriter.notifyMatchFailure(op,
385-
"sparse tensor copy not implemented");
386-
Location loc = op.getLoc();
387383
const auto stt = getSparseTensorType(op);
388384
if (!stt.hasEncoding())
389385
return failure();
386+
if (op.getCopy())
387+
return rewriter.notifyMatchFailure(op, "alloc copy not implemented");
390388
// Gather all dimension sizes as SSA values.
389+
Location loc = op.getLoc();
391390
const Dimension dimRank = stt.getDimRank();
392391
SmallVector<Value> dimSizes;
393392
dimSizes.reserve(dimRank);

0 commit comments

Comments
 (0)