Skip to content

Commit 2323f48

Browse files
authored
[mlir][sparse] refactor dim2lvl/lvl2dim lvlsizes setup (#72474)
This change provides access to the individual dim-size and lvl-size values after each CodegenUtils call. This is step 2 out of 3 to make sparse_tensor.new work for BSR.
1 parent c6b95f3 commit 2323f48

File tree

4 files changed

+60
-60
lines changed

4 files changed

+60
-60
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -639,25 +639,20 @@ Value sparse_tensor::createOrFoldSliceStrideOp(OpBuilder &builder, Location loc,
639639
return builder.create<ToSliceStrideOp>(loc, tensor, APInt(64, dim));
640640
}
641641

642-
void sparse_tensor::fillDimShape(OpBuilder &builder, Location loc,
643-
SparseTensorType stt,
644-
SmallVectorImpl<Value> &out) {
645-
out.clear();
646-
out.reserve(stt.getDimRank());
647-
for (const Size sz : stt.getDimShape()) {
648-
const auto s = ShapedType::isDynamic(sz) ? 0 : sz;
649-
out.push_back(constantIndex(builder, loc, s));
650-
}
651-
}
652-
653642
Value sparse_tensor::genReader(OpBuilder &builder, Location loc,
654643
SparseTensorType stt, Value tensor,
655-
/*out*/ SmallVectorImpl<Value> &dimShapesValues,
644+
/*out*/ SmallVectorImpl<Value> &dimSizesValues,
656645
/*out*/ Value &dimSizesBuffer) {
657-
// Construct the dimShapes buffer. The buffer contains the static size
658-
// per dimension, or otherwise a zero for a dynamic size.
659-
fillDimShape(builder, loc, stt, dimShapesValues);
660-
Value dimShapesBuffer = allocaBuffer(builder, loc, dimShapesValues);
646+
// Construct the dimension **shapes** buffer. The buffer contains the static
647+
// size per dimension, or otherwise a zero for a dynamic size.
648+
Dimension dimRank = stt.getDimRank();
649+
dimSizesValues.clear();
650+
dimSizesValues.reserve(dimRank);
651+
for (const Size sz : stt.getDimShape()) {
652+
const auto s = ShapedType::isDynamic(sz) ? 0 : sz;
653+
dimSizesValues.push_back(constantIndex(builder, loc, s));
654+
}
655+
Value dimShapesBuffer = allocaBuffer(builder, loc, dimSizesValues);
661656
// Create the `CheckedSparseTensorReader`. This reader performs a
662657
// consistency check on the static sizes, but accepts any size
663658
// of each dimension with a dynamic size.
@@ -679,29 +674,40 @@ Value sparse_tensor::genReader(OpBuilder &builder, Location loc,
679674
createFuncCall(builder, loc, "getSparseTensorReaderDimSizes", memTp,
680675
reader, EmitCInterface::On)
681676
.getResult(0);
677+
// Also convert the dim shapes values into dim sizes values, just in case
678+
// subsequent clients need the values (DCE will remove unused).
679+
for (Dimension d = 0; d < dimRank; d++) {
680+
if (stt.isDynamicDim(d))
681+
dimSizesValues[d] = builder.create<memref::LoadOp>(
682+
loc, dimSizesBuffer, constantIndex(builder, loc, d));
683+
}
682684
}
683685
return reader;
684686
}
685687

686-
Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
687-
SparseTensorType stt,
688-
ArrayRef<Value> dimShapesValues,
689-
Value dimSizesBuffer,
690-
/*out*/ Value &dim2lvlBuffer,
691-
/*out*/ Value &lvl2dimBuffer) {
688+
Value sparse_tensor::genMapBuffers(
689+
OpBuilder &builder, Location loc, SparseTensorType stt,
690+
ArrayRef<Value> dimSizesValues, Value dimSizesBuffer,
691+
/*out*/ SmallVectorImpl<Value> &lvlSizesValues,
692+
/*out*/ Value &dim2lvlBuffer,
693+
/*out*/ Value &lvl2dimBuffer) {
692694
const Dimension dimRank = stt.getDimRank();
693695
const Level lvlRank = stt.getLvlRank();
696+
lvlSizesValues.clear();
697+
lvlSizesValues.reserve(lvlRank);
694698
// For an identity mapping, the dim2lvl and lvl2dim mappings are
695699
// identical as are dimSizes and lvlSizes, so buffers are reused
696700
// as much as possible.
697701
if (stt.isIdentity()) {
698702
assert(dimRank == lvlRank);
699703
SmallVector<Value> iotaValues;
700704
iotaValues.reserve(lvlRank);
701-
for (Level l = 0; l < lvlRank; l++)
705+
for (Level l = 0; l < lvlRank; l++) {
702706
iotaValues.push_back(constantIndex(builder, loc, l));
707+
lvlSizesValues.push_back(dimSizesValues[l]);
708+
}
703709
dim2lvlBuffer = lvl2dimBuffer = allocaBuffer(builder, loc, iotaValues);
704-
return dimSizesBuffer;
710+
return dimSizesBuffer; // now lvlSizesBuffer
705711
}
706712
// Otherwise, some code needs to be generated to set up the buffers.
707713
// This code deals with permutations as well as non-permutations that
@@ -710,7 +716,6 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
710716
const auto lvlToDim = stt.getLvlToDim();
711717
SmallVector<Value> dim2lvlValues(lvlRank); // for each lvl, expr in dim vars
712718
SmallVector<Value> lvl2dimValues(dimRank); // for each dim, expr in lvl vars
713-
SmallVector<Value> lvlSizesValues(lvlRank);
714719
// Generate dim2lvl.
715720
assert(lvlRank == dimToLvl.getNumResults());
716721
for (Level l = 0; l < lvlRank; l++) {
@@ -748,17 +753,14 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
748753
// (3) l = d % c : c
749754
Value lvlSz;
750755
if (cm == 0) {
751-
lvlSz = dimShapesValues[d];
752-
if (stt.isDynamicDim(d))
753-
lvlSz = builder.create<memref::LoadOp>(loc, dimSizesBuffer,
754-
constantIndex(builder, loc, d));
756+
lvlSz = dimSizesValues[d];
755757
if (cf != 0)
756758
lvlSz = builder.create<arith::DivUIOp>(loc, lvlSz,
757759
constantIndex(builder, loc, cf));
758760
} else {
759761
lvlSz = constantIndex(builder, loc, cm);
760762
}
761-
lvlSizesValues[l] = lvlSz;
763+
lvlSizesValues.push_back(lvlSz);
762764
}
763765
// Generate lvl2dim.
764766
assert(dimRank == lvlToDim.getNumResults());
@@ -792,5 +794,5 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
792794
// Return buffers.
793795
dim2lvlBuffer = allocaBuffer(builder, loc, dim2lvlValues);
794796
lvl2dimBuffer = allocaBuffer(builder, loc, lvl2dimValues);
795-
return allocaBuffer(builder, loc, lvlSizesValues);
797+
return allocaBuffer(builder, loc, lvlSizesValues); // lvlSizesBuffer
796798
}

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -317,20 +317,16 @@ Value createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc, Value tensor,
317317
Value createOrFoldSliceStrideOp(OpBuilder &builder, Location loc, Value tensor,
318318
Dimension dim);
319319

320-
/// Populates the array with the dimension-shape of the given
321-
/// `SparseTensorType`, where dynamic sizes are represented by zero.
322-
void fillDimShape(OpBuilder &builder, Location loc, SparseTensorType stt,
323-
SmallVectorImpl<Value> &out);
324-
325320
/// Generates code that opens a reader and sets the dimension sizes.
326321
Value genReader(OpBuilder &builder, Location loc, SparseTensorType stt,
327322
Value tensor,
328-
/*out*/ SmallVectorImpl<Value> &dimShapeValues,
323+
/*out*/ SmallVectorImpl<Value> &dimSizesValues,
329324
/*out*/ Value &dimSizesBuffer);
330325

331326
/// Generates code to set up the buffer parameters for a map.
332327
Value genMapBuffers(OpBuilder &builder, Location loc, SparseTensorType stt,
333-
ArrayRef<Value> dimShapeValues, Value dimSizesBuffer,
328+
ArrayRef<Value> dimSizesValues, Value dimSizesBuffer,
329+
/*out*/ SmallVectorImpl<Value> &lvlSizesValues,
334330
/*out*/ Value &dim2lvlBuffer,
335331
/*out*/ Value &lvl2dimBuffer);
336332

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1484,11 +1484,12 @@ struct SparseNewConverter : public OpConversionPattern<NewOp> {
14841484
createAllocFields(rewriter, loc, dstTp, dynSizes, /*enableInit=*/false,
14851485
fields, nse);
14861486

1487-
// Now construct the dim2lvl and lvl2dim buffers.
1487+
// Now construct the lvl sizes and the dim2lvl/lvl2dim buffers.
1488+
SmallVector<Value> lvlSizesValues;
14881489
Value dim2lvlBuffer;
14891490
Value lvl2dimBuffer;
14901491
genMapBuffers(rewriter, loc, dstTp, dimShapesValues, dimSizesBuffer,
1491-
dim2lvlBuffer, lvl2dimBuffer);
1492+
lvlSizesValues, dim2lvlBuffer, lvl2dimBuffer);
14921493

14931494
// Read the COO tensor data.
14941495
MutSparseTensorDescriptor desc(dstTp, fields);

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,10 @@ class NewCallParams final {
199199
params[kParamDimSizes] = dimSizesBuffer
200200
? dimSizesBuffer
201201
: allocaBuffer(builder, loc, dimSizesValues);
202-
params[kParamLvlSizes] =
203-
genMapBuffers(builder, loc, stt, dimSizesValues, params[kParamDimSizes],
204-
params[kParamDim2Lvl], params[kParamLvl2Dim]);
202+
SmallVector<Value> lvlSizesValues; // unused
203+
params[kParamLvlSizes] = genMapBuffers(
204+
builder, loc, stt, dimSizesValues, params[kParamDimSizes],
205+
lvlSizesValues, params[kParamDim2Lvl], params[kParamLvl2Dim]);
205206
// Secondary and primary types encoding.
206207
const auto enc = stt.getEncoding();
207208
params[kParamPosTp] = constantPosTypeEncoding(builder, loc, enc);
@@ -369,13 +370,13 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
369370
if (!stt.hasEncoding())
370371
return failure();
371372
// Construct the `reader` opening method calls.
372-
SmallVector<Value> dimShapesValues;
373+
SmallVector<Value> dimSizesValues;
373374
Value dimSizesBuffer;
374375
Value reader = genReader(rewriter, loc, stt, adaptor.getOperands()[0],
375-
dimShapesValues, dimSizesBuffer);
376+
dimSizesValues, dimSizesBuffer);
376377
// Use the `reader` to parse the file.
377378
Value tensor = NewCallParams(rewriter, loc)
378-
.genBuffers(stt, dimShapesValues, dimSizesBuffer)
379+
.genBuffers(stt, dimSizesValues, dimSizesBuffer)
379380
.genNewCall(Action::kFromReader, reader);
380381
// Free the memory for `reader`.
381382
createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
@@ -402,19 +403,19 @@ class SparseTensorAllocConverter
402403
// Gather all dimension sizes as SSA values.
403404
Location loc = op.getLoc();
404405
const Dimension dimRank = stt.getDimRank();
405-
SmallVector<Value> dimSizes;
406-
dimSizes.reserve(dimRank);
406+
SmallVector<Value> dimSizesValues;
407+
dimSizesValues.reserve(dimRank);
407408
unsigned operandCtr = 0;
408409
for (Dimension d = 0; d < dimRank; d++) {
409-
dimSizes.push_back(
410+
dimSizesValues.push_back(
410411
stt.isDynamicDim(d)
411412
? adaptor.getOperands()[operandCtr++]
412413
: constantIndex(rewriter, loc, op.getStaticSize(d)));
413414
}
414415
// Generate the call to construct empty tensor. The sizes are
415416
// explicitly defined by the arguments to the alloc operator.
416417
rewriter.replaceOp(op, NewCallParams(rewriter, loc)
417-
.genBuffers(stt, dimSizes)
418+
.genBuffers(stt, dimSizesValues)
418419
.genNewCall(Action::kEmpty));
419420
return success();
420421
}
@@ -433,19 +434,19 @@ class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
433434
return failure();
434435
// Gather all dimension sizes as SSA values.
435436
const Dimension dimRank = stt.getDimRank();
436-
SmallVector<Value> dimSizes;
437-
dimSizes.reserve(dimRank);
437+
SmallVector<Value> dimSizesValues;
438+
dimSizesValues.reserve(dimRank);
438439
auto shape = op.getType().getShape();
439440
unsigned operandCtr = 0;
440441
for (Dimension d = 0; d < dimRank; d++) {
441-
dimSizes.push_back(stt.isDynamicDim(d)
442-
? adaptor.getOperands()[operandCtr++]
443-
: constantIndex(rewriter, loc, shape[d]));
442+
dimSizesValues.push_back(stt.isDynamicDim(d)
443+
? adaptor.getOperands()[operandCtr++]
444+
: constantIndex(rewriter, loc, shape[d]));
444445
}
445446
// Generate the call to construct empty tensor. The sizes are
446447
// explicitly defined by the arguments to the alloc operator.
447448
rewriter.replaceOp(op, NewCallParams(rewriter, loc)
448-
.genBuffers(stt, dimSizes)
449+
.genBuffers(stt, dimSizesValues)
449450
.genNewCall(Action::kEmpty));
450451
return success();
451452
}
@@ -467,8 +468,8 @@ class SparseTensorReorderCOOConverter
467468
const Value src = adaptor.getInputCoo();
468469

469470
NewCallParams params(rewriter, loc);
470-
SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src);
471-
rewriter.replaceOp(op, params.genBuffers(dstTp, dimSizes)
471+
SmallVector<Value> dimSizesValues = getDimSizes(rewriter, loc, srcTp, src);
472+
rewriter.replaceOp(op, params.genBuffers(dstTp, dimSizesValues)
472473
.genNewCall(Action::kSortCOOInPlace, src));
473474

474475
return success();
@@ -706,14 +707,14 @@ class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
706707
const Location loc = op->getLoc();
707708
const auto dstTp = getSparseTensorType(op.getResult());
708709
assert(dstTp.hasStaticDimShape());
709-
SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, dstTp);
710+
SmallVector<Value> dimSizesValues = getDimSizes(rewriter, loc, dstTp);
710711
// Use a library method to transfer the external buffers from
711712
// clients to the internal SparseTensorStorage. Since we cannot
712713
// assume clients transfer ownership of the buffers, this method
713714
// will copy all data over into a new SparseTensorStorage.
714715
Value dst =
715716
NewCallParams(rewriter, loc)
716-
.genBuffers(dstTp.withoutDimToLvl(), dimSizes)
717+
.genBuffers(dstTp.withoutDimToLvl(), dimSizesValues)
717718
.genNewCall(Action::kPack,
718719
genLvlPtrsBuffers(rewriter, loc, adaptor.getLevels(),
719720
adaptor.getValues()));

0 commit comments

Comments
 (0)