Skip to content

[mlir][sparse] move all COO related methods into SparseTensorType #73881

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 0 additions & 13 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,19 +89,6 @@ inline MemRefType getMemRefType(T &&t) {
/// Returns null-attribute for any type without an encoding.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type);

/// Returns true iff the given sparse tensor encoding attribute has a trailing
/// COO region starting at the given level.
bool isCOOType(SparseTensorEncodingAttr enc, Level startLvl, bool isUnique);

/// Returns true iff the given type is a COO type where the last level
/// is unique.
bool isUniqueCOOType(Type tp);

/// Returns the starting level for a trailing COO region that spans
/// at least two levels. If no such COO region is found, then returns
/// the level-rank.
Level getCOOStart(SparseTensorEncodingAttr enc);

/// Returns true iff MLIR operand has any sparse operand.
inline bool hasAnySparseOperand(Operation *op) {
return llvm::any_of(op->getOperands().getTypes(), [](Type t) {
Expand Down
20 changes: 18 additions & 2 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ class SparseTensorType {
: SparseTensorType(
RankedTensorType::get(stp.getShape(), stp.getElementType(), enc)) {}

// TODO: remove?
SparseTensorType(SparseTensorEncodingAttr enc)
: SparseTensorType(RankedTensorType::get(
SmallVector<Size>(enc.getDimRank(), ShapedType::kDynamic),
Float32Type::get(enc.getContext()), enc)) {}

SparseTensorType &operator=(const SparseTensorType &) = delete;
SparseTensorType(const SparseTensorType &) = default;

Expand Down Expand Up @@ -234,9 +240,9 @@ class SparseTensorType {
CrdTransDirectionKind::dim2lvl);
}

/// Returns the type with an identity mapping.
RankedTensorType getDemappedType() const {
auto lvlShape = getLvlShape();
return RankedTensorType::get(lvlShape, rtp.getElementType(),
return RankedTensorType::get(getLvlShape(), getElementType(),
enc.withoutDimToLvl());
}

Expand Down Expand Up @@ -311,6 +317,16 @@ class SparseTensorType {
return IndexType::get(getContext());
}

/// Returns true iff this sparse tensor type has a trailing
/// COO region starting at the given level. By default, it
/// tests for a unique COO type at top level.
bool isCOOType(Level startLvl = 0, bool isUnique = true) const;

/// Returns the starting level of this sparse tensor type for a
/// trailing COO region that spans **at least** two levels. If
/// no such COO region is found, then returns the level-rank.
Level getCOOStart() const;

/// Returns [un]ordered COO type for this sparse tensor type.
RankedTensorType getCOOType(bool ordered) const;

Expand Down
80 changes: 36 additions & 44 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ void StorageLayout::foreachField(
callback) const {
const auto lvlTypes = enc.getLvlTypes();
const Level lvlRank = enc.getLvlRank();
const Level cooStart = getCOOStart(enc);
const Level cooStart = SparseTensorType(enc).getCOOStart();
const Level end = cooStart == lvlRank ? cooStart : cooStart + 1;
FieldIndex fieldIdx = kDataFieldStartingIdx;
// Per-level storage.
Expand Down Expand Up @@ -158,7 +158,7 @@ StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
unsigned stride = 1;
if (kind == SparseTensorFieldKind::CrdMemRef) {
assert(lvl.has_value());
const Level cooStart = getCOOStart(enc);
const Level cooStart = SparseTensorType(enc).getCOOStart();
const Level lvlRank = enc.getLvlRank();
if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
lvl = cooStart;
Expand Down Expand Up @@ -710,6 +710,29 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
// SparseTensorType Methods.
//===----------------------------------------------------------------------===//

/// Tests whether this sparse tensor type has a trailing COO region that
/// begins at `startLvl`: one (loose) compressed level followed only by
/// singleton levels. When `isUnique` is set, the last level must also
/// carry the unique property.
bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl,
                                                      bool isUnique) const {
  // Dense tensors (no encoding) never form a COO region.
  if (!hasEncoding())
    return false;
  // The region must open with a compressed or loose-compressed level.
  const bool opensRegion =
      isCompressedLvl(startLvl) || isLooseCompressedLvl(startLvl);
  if (!opensRegion)
    return false;
  // Every level after the opening one must be a singleton.
  for (Level lvl = startLvl + 1; lvl < lvlRank; lvl++)
    if (!isSingletonLvl(lvl))
      return false;
  // When uniqueness is requested, the last level carries it: with
  // lvlRank == 1 that is the sole compressed level, otherwise it is
  // the trailing singleton.
  if (isUnique)
    return isUniqueLvl(lvlRank - 1);
  return true;
}

/// Finds the first level that opens a trailing COO region spanning at
/// least two levels. Returns the level-rank when no such region exists.
Level mlir::sparse_tensor::SparseTensorType::getCOOStart() const {
  // Only encoded tensors with more than one level can contain a COO
  // region of two or more levels.
  if (!hasEncoding() || lvlRank <= 1)
    return lvlRank;
  // Scan candidate start levels; the region must span >= 2 levels,
  // so the last level alone cannot open one.
  Level lvl = 0;
  while (lvl + 1 < lvlRank) {
    if (isCOOType(lvl, /*isUnique=*/false))
      return lvl;
    lvl++;
  }
  return lvlRank;
}

RankedTensorType
mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
SmallVector<LevelType> lvlTypes;
Expand Down Expand Up @@ -859,25 +882,6 @@ bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {
return !coeffientMap.empty();
}

/// Returns true iff `enc` has a trailing COO region starting at level
/// `startLvl`: a (loose) compressed level followed only by singleton
/// levels. When `isUnique` is set, the last level must be unique.
bool mlir::sparse_tensor::isCOOType(SparseTensorEncodingAttr enc,
                                    Level startLvl, bool isUnique) {
  // A null encoding (dense tensor) never forms a COO region.
  if (!enc)
    return false;
  // The region must open with a compressed or loose-compressed level.
  if (!enc.isCompressedLvl(startLvl) && !enc.isLooseCompressedLvl(startLvl))
    return false;
  const Level lvlRank = enc.getLvlRank();
  // All remaining levels must be singletons.
  for (Level lvl = startLvl + 1; lvl < lvlRank; lvl++)
    if (!enc.isSingletonLvl(lvl))
      return false;
  // When uniqueness is requested, the last level carries it: the sole
  // compressed level when lvlRank == 1, the trailing singleton otherwise.
  if (isUnique)
    return enc.isUniqueLvl(lvlRank - 1);
  return true;
}

/// Returns true iff `tp` is a COO type whose last level is unique.
bool mlir::sparse_tensor::isUniqueCOOType(Type tp) {
  const auto enc = getSparseTensorEncoding(tp);
  return isCOOType(enc, /*startLvl=*/0, /*isUnique=*/true);
}

bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
auto hasNonIdentityMap = [](Value v) {
auto stt = tryGetSparseTensorType(v);
Expand All @@ -888,17 +892,6 @@ bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
llvm::any_of(op->getResults(), hasNonIdentityMap);
}

/// Returns the starting level of a trailing COO region spanning at
/// least two levels, or the level-rank when none exists. Regions are
/// required to span >= 2 levels for the AOS storage optimization.
Level mlir::sparse_tensor::getCOOStart(SparseTensorEncodingAttr enc) {
  const Level lvlRank = enc.getLvlRank();
  // A single-level tensor cannot host a multi-level COO region.
  if (lvlRank <= 1)
    return lvlRank;
  // The last level alone cannot open a >= 2 level region, so stop
  // one short of the rank.
  for (Level lvl = 0; lvl + 1 < lvlRank; lvl++)
    if (isCOOType(enc, lvl, /*isUnique=*/false))
      return lvl;
  return lvlRank;
}

Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
if (enc) {
assert(enc.isPermutation() && "Non permutation map not supported");
Expand Down Expand Up @@ -1013,7 +1006,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
return op->emitError("the sparse-tensor must have the identity mapping");

// Verifies the trailing COO.
Level cooStartLvl = getCOOStart(stt.getEncoding());
Level cooStartLvl = stt.getCOOStart();
if (cooStartLvl < stt.getLvlRank()) {
// We only support trailing COO for now; it must be the last input.
auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
Expand Down Expand Up @@ -1309,34 +1302,34 @@ OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
}

LogicalResult ToPositionsOp::verify() {
auto e = getSparseTensorEncoding(getTensor().getType());
auto stt = getSparseTensorType(getTensor());
if (failed(lvlIsInBounds(getLevel(), getTensor())))
return emitError("requested level is out of bounds");
if (failed(isMatchingWidth(getResult(), e.getPosWidth())))
if (failed(isMatchingWidth(getResult(), stt.getPosWidth())))
return emitError("unexpected type for positions");
return success();
}

LogicalResult ToCoordinatesOp::verify() {
auto e = getSparseTensorEncoding(getTensor().getType());
auto stt = getSparseTensorType(getTensor());
if (failed(lvlIsInBounds(getLevel(), getTensor())))
return emitError("requested level is out of bounds");
if (failed(isMatchingWidth(getResult(), e.getCrdWidth())))
if (failed(isMatchingWidth(getResult(), stt.getCrdWidth())))
return emitError("unexpected type for coordinates");
return success();
}

LogicalResult ToCoordinatesBufferOp::verify() {
auto e = getSparseTensorEncoding(getTensor().getType());
if (getCOOStart(e) >= e.getLvlRank())
auto stt = getSparseTensorType(getTensor());
if (stt.getCOOStart() >= stt.getLvlRank())
return emitError("expected sparse tensor with a COO region");
return success();
}

LogicalResult ToValuesOp::verify() {
auto ttp = getRankedTensorType(getTensor());
auto stt = getSparseTensorType(getTensor());
auto mtp = getMemRefType(getResult());
if (ttp.getElementType() != mtp.getElementType())
if (stt.getElementType() != mtp.getElementType())
return emitError("unexpected mismatch in element types");
return success();
}
Expand Down Expand Up @@ -1660,9 +1653,8 @@ LogicalResult ReorderCOOOp::verify() {
SparseTensorType srcStt = getSparseTensorType(getInputCoo());
SparseTensorType dstStt = getSparseTensorType(getResultCoo());

if (!isCOOType(srcStt.getEncoding(), 0, /*isUnique=*/true) ||
!isCOOType(dstStt.getEncoding(), 0, /*isUnique=*/true))
emitError("Unexpected non-COO sparse tensors");
if (!srcStt.isCOOType() || !dstStt.isCOOType())
emitError("Expected COO sparse tensors only");

if (!srcStt.hasSameDimToLvl(dstStt))
emitError("Unmatched dim2lvl map between input and result COO");
Expand Down
7 changes: 3 additions & 4 deletions mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,8 +412,7 @@ void LoopEmitter::initializeLoopEmit(
auto stt = getSparseTensorType(tensor);
const Level lvlRank = stt.getLvlRank();
const auto shape = rtp.getShape();
const auto enc = getSparseTensorEncoding(rtp);
const Level cooStart = enc ? getCOOStart(enc) : lvlRank;
const Level cooStart = stt.getCOOStart();

SmallVector<Value> lvlSzs;
for (Level l = 0; l < stt.getLvlRank(); l++) {
Expand Down Expand Up @@ -457,8 +456,8 @@ void LoopEmitter::initializeLoopEmit(
// values.
// Delegates extra output initialization to clients.
bool isOutput = isOutputTensor(t);
Type elementType = rtp.getElementType();
if (!enc) {
Type elementType = stt.getElementType();
if (!stt.hasEncoding()) {
// Non-annotated dense tensors.
BaseMemRefType denseTp = MemRefType::get(shape, elementType);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ static void createAllocFields(OpBuilder &builder, Location loc,
valHeuristic =
builder.create<arith::MulIOp>(loc, valHeuristic, lvlSizesValues[lvl]);
} else if (sizeHint) {
if (getCOOStart(stt.getEncoding()) == 0) {
if (stt.getCOOStart() == 0) {
posHeuristic = constantIndex(builder, loc, 2);
crdHeuristic = builder.create<arith::MulIOp>(
loc, constantIndex(builder, loc, lvlRank), sizeHint); // AOS
Expand Down Expand Up @@ -657,8 +657,7 @@ struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {

// Should have been verified.
assert(dstStt.isAllOrdered() && !srcStt.isAllOrdered() &&
isUniqueCOOType(srcStt.getRankedTensorType()) &&
isUniqueCOOType(dstStt.getRankedTensorType()));
dstStt.isCOOType() && srcStt.isCOOType());
assert(dstStt.hasSameDimToLvl(srcStt));

// We don't need a mutable descriptor here as we perform sorting in-place.
Expand Down Expand Up @@ -1317,7 +1316,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
Value posBack = c0; // index to the last value in the position array
Value memSize = c1; // memory size for current array

Level trailCOOStart = getCOOStart(stt.getEncoding());
Level trailCOOStart = stt.getCOOStart();
Level trailCOORank = stt.getLvlRank() - trailCOOStart;
// Sets up SparseTensorSpecifier.
for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
Expand Down Expand Up @@ -1454,7 +1453,7 @@ struct SparseNewConverter : public OpConversionPattern<NewOp> {
const auto dstTp = getSparseTensorType(op.getResult());
// Creating COO with NewOp is handled by direct IR codegen. All other cases
// are handled by rewriting.
if (!dstTp.hasEncoding() || getCOOStart(dstTp.getEncoding()) != 0)
if (!dstTp.hasEncoding() || dstTp.getCOOStart() != 0)
return failure();

// Implement as follows:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ void SparseTensorSpecifier::setSpecifierField(OpBuilder &builder, Location loc,

Value sparse_tensor::SparseTensorDescriptor::getCrdMemRefOrView(
OpBuilder &builder, Location loc, Level lvl) const {
const Level cooStart = getCOOStart(rType.getEncoding());
const Level cooStart = rType.getCOOStart();
if (lvl < cooStart)
return getMemRefField(SparseTensorFieldKind::CrdMemRef, lvl);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ class SparseTensorDescriptorImpl {
}

Value getAOSMemRef() const {
const Level cooStart = getCOOStart(rType.getEncoding());
const Level cooStart = rType.getCOOStart();
assert(cooStart < rType.getLvlRank());
return getMemRefField(SparseTensorFieldKind::CrdMemRef, cooStart);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1180,8 +1180,7 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto stt = getSparseTensorType(op.getResult());
auto enc = stt.getEncoding();
if (!stt.hasEncoding() || getCOOStart(enc) == 0)
if (!stt.hasEncoding() || stt.getCOOStart() == 0)
return failure();

// Implement the NewOp as follows:
Expand All @@ -1192,6 +1191,7 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
RankedTensorType cooTp = stt.getCOOType(/*ordered=*/true);
Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
Value convert = cooTensor;
auto enc = stt.getEncoding();
if (!stt.isPermutation()) { // demap coo, demap dstTp
auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
convert = rewriter.create<ReinterpretMapOp>(loc, coo, convert);
Expand Down