Commit bc61122

[mlir][sparse] reformat SparseTensorCodegen file (llvm#71231)
1 parent 1bdb166 commit bc61122

File tree

1 file changed (+93, -92 lines)


mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

Lines changed: 93 additions & 92 deletions
@@ -10,8 +10,8 @@
 // visible buffers and actual compiler IR that implements these primitives on
 // the selected sparse tensor storage schemes. This pass provides an alternative
 // to the SparseTensorConversion pass, eliminating the dependence on a runtime
-// support library, and providing much more opportunities for subsequent
-// compiler optimization of the generated code.
+// support library (other than for file I/O), and providing many more
+// opportunities for subsequent compiler optimization of the generated code.
 //
 //===----------------------------------------------------------------------===//

@@ -37,16 +37,11 @@
 using namespace mlir;
 using namespace mlir::sparse_tensor;

-namespace {
-
-using FuncGeneratorType =
-    function_ref<void(OpBuilder &, ModuleOp, func::FuncOp, RankedTensorType)>;
-
 //===----------------------------------------------------------------------===//
 // Helper methods.
 //===----------------------------------------------------------------------===//

-/// Flatten a list of operands that may contain sparse tensors.
+/// Flattens a list of operands that may contain sparse tensors.
 static void flattenOperands(ValueRange operands,
                             SmallVectorImpl<Value> &flattened) {
   // In case of
@@ -97,6 +92,7 @@ static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
   return forOp;
 }

+/// Creates a push back operation.
 static void createPushback(OpBuilder &builder, Location loc,
                            MutSparseTensorDescriptor desc,
                            SparseTensorFieldKind kind, std::optional<Level> lvl,
@@ -368,6 +364,95 @@ static Value genCompressed(OpBuilder &builder, Location loc,
   return ifOp2.getResult(o);
 }

+/// Generates insertion finalization code.
+static void genEndInsert(OpBuilder &builder, Location loc,
+                         SparseTensorDescriptor desc) {
+  const SparseTensorType stt(desc.getRankedTensorType());
+  const Level lvlRank = stt.getLvlRank();
+  for (Level l = 0; l < lvlRank; l++) {
+    const auto dlt = stt.getLvlType(l);
+    if (isLooseCompressedDLT(dlt))
+      llvm_unreachable("TODO: Not yet implemented");
+    if (isCompressedDLT(dlt)) {
+      // Compressed dimensions need a position cleanup for all entries
+      // that were not visited during the insertion pass.
+      //
+      // TODO: avoid cleanup and keep compressed scheme consistent at all
+      // times?
+      //
+      if (l > 0) {
+        Type posType = stt.getPosType();
+        Value posMemRef = desc.getPosMemRef(l);
+        Value hi = desc.getPosMemSize(builder, loc, l);
+        Value zero = constantIndex(builder, loc, 0);
+        Value one = constantIndex(builder, loc, 1);
+        // Vector of only one, but needed by createFor's prototype.
+        SmallVector<Value, 1> inits{genLoad(builder, loc, posMemRef, zero)};
+        scf::ForOp loop = createFor(builder, loc, hi, inits, one);
+        Value i = loop.getInductionVar();
+        Value oldv = loop.getRegionIterArg(0);
+        Value newv = genLoad(builder, loc, posMemRef, i);
+        Value posZero = constantZero(builder, loc, posType);
+        Value cond = builder.create<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::eq, newv, posZero);
+        scf::IfOp ifOp = builder.create<scf::IfOp>(loc, TypeRange(posType),
+                                                   cond, /*else*/ true);
+        builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+        genStore(builder, loc, oldv, posMemRef, i);
+        builder.create<scf::YieldOp>(loc, oldv);
+        builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+        builder.create<scf::YieldOp>(loc, newv);
+        builder.setInsertionPointAfter(ifOp);
+        builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
+        builder.setInsertionPointAfter(loop);
+      }
+    } else {
+      assert(isDenseDLT(dlt) || isSingletonDLT(dlt));
+    }
+  }
+}
+
+/// Generates a subview into the sizes.
+static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
+                            Value sz) {
+  auto elemTp = llvm::cast<MemRefType>(mem.getType()).getElementType();
+  return builder
+      .create<memref::SubViewOp>(
+          loc, MemRefType::get({ShapedType::kDynamic}, elemTp), mem,
+          ValueRange{}, ValueRange{sz}, ValueRange{},
+          ArrayRef<int64_t>{0},                    // static offset
+          ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size
+          ArrayRef<int64_t>{1})                    // static stride
+      .getResult();
+}
+
+/// Creates the reassociation array.
+static ReassociationIndices getReassociationForFlattening(ShapedType srcTp) {
+  ReassociationIndices reassociation;
+  for (int i = 0, e = srcTp.getRank(); i < e; i++)
+    reassociation.push_back(i);
+  return reassociation;
+}
+
+/// Generates scalar to tensor cast.
+static Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem,
+                               Type dstTp) {
+  if (auto rtp = dstTp.dyn_cast<RankedTensorType>()) {
+    // Scalars can only be converted to 0-ranked tensors.
+    if (rtp.getRank() != 0)
+      return nullptr;
+    elem = genCast(builder, loc, elem, rtp.getElementType());
+    return builder.create<tensor::FromElementsOp>(loc, rtp, elem);
+  }
+  return genCast(builder, loc, elem, dstTp);
+}
+
+//===----------------------------------------------------------------------===//
+// Codegen rules.
+//===----------------------------------------------------------------------===//
+
+namespace {
+
 /// Helper class to help lowering sparse_tensor.insert operation.
 class SparseInsertGenerator
     : public FuncCallOrInlineGenerator<SparseInsertGenerator> {
@@ -472,90 +557,6 @@ class SparseInsertGenerator
   TensorType rtp;
 };

-/// Generations insertion finalization code.
-static void genEndInsert(OpBuilder &builder, Location loc,
-                         SparseTensorDescriptor desc) {
-  const SparseTensorType stt(desc.getRankedTensorType());
-  const Level lvlRank = stt.getLvlRank();
-  for (Level l = 0; l < lvlRank; l++) {
-    const auto dlt = stt.getLvlType(l);
-    if (isLooseCompressedDLT(dlt))
-      llvm_unreachable("TODO: Not yet implemented");
-    if (isCompressedDLT(dlt)) {
-      // Compressed dimensions need a position cleanup for all entries
-      // that were not visited during the insertion pass.
-      //
-      // TODO: avoid cleanup and keep compressed scheme consistent at all
-      // times?
-      //
-      if (l > 0) {
-        Type posType = stt.getPosType();
-        Value posMemRef = desc.getPosMemRef(l);
-        Value hi = desc.getPosMemSize(builder, loc, l);
-        Value zero = constantIndex(builder, loc, 0);
-        Value one = constantIndex(builder, loc, 1);
-        // Vector of only one, but needed by createFor's prototype.
-        SmallVector<Value, 1> inits{genLoad(builder, loc, posMemRef, zero)};
-        scf::ForOp loop = createFor(builder, loc, hi, inits, one);
-        Value i = loop.getInductionVar();
-        Value oldv = loop.getRegionIterArg(0);
-        Value newv = genLoad(builder, loc, posMemRef, i);
-        Value posZero = constantZero(builder, loc, posType);
-        Value cond = builder.create<arith::CmpIOp>(
-            loc, arith::CmpIPredicate::eq, newv, posZero);
-        scf::IfOp ifOp = builder.create<scf::IfOp>(loc, TypeRange(posType),
-                                                   cond, /*else*/ true);
-        builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-        genStore(builder, loc, oldv, posMemRef, i);
-        builder.create<scf::YieldOp>(loc, oldv);
-        builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-        builder.create<scf::YieldOp>(loc, newv);
-        builder.setInsertionPointAfter(ifOp);
-        builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
-        builder.setInsertionPointAfter(loop);
-      }
-    } else {
-      assert(isDenseDLT(dlt) || isSingletonDLT(dlt));
-    }
-  }
-}
-
-static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
-                            Value sz) {
-  auto elemTp = llvm::cast<MemRefType>(mem.getType()).getElementType();
-  return builder
-      .create<memref::SubViewOp>(
-          loc, MemRefType::get({ShapedType::kDynamic}, elemTp), mem,
-          ValueRange{}, ValueRange{sz}, ValueRange{},
-          ArrayRef<int64_t>{0},                    // static offset
-          ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size
-          ArrayRef<int64_t>{1})                    // static stride
-      .getResult();
-}
-
-static ReassociationIndices getReassociationForFlattening(ShapedType srcTp) {
-  ReassociationIndices reassociation;
-  for (int i = 0, e = srcTp.getRank(); i < e; i++)
-    reassociation.push_back(i);
-  return reassociation;
-}
-
-static Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem,
-                               Type dstTp) {
-  if (auto rtp = dstTp.dyn_cast<RankedTensorType>()) {
-    // Scalars can only be converted to 0-ranked tensors.
-    if (rtp.getRank() != 0)
-      return nullptr;
-    elem = genCast(builder, loc, elem, rtp.getElementType());
-    return builder.create<tensor::FromElementsOp>(loc, rtp, elem);
-  }
-  return genCast(builder, loc, elem, dstTp);
-}
-
-//===----------------------------------------------------------------------===//
-// Codegen rules.
-//===----------------------------------------------------------------------===//
-
 /// Sparse tensor storage conversion rule for returns.
 class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
 public:
