[mlir][sparse] unify support of (dis)assemble between direct IR/lib path #71880
Conversation
Note that the (dis)assemble operations still make some simplifying assumptions (e.g., trailing 2-D COO in AoS format), but now at least the direct IR and support library paths behave exactly the same. Generalizing the ops is still TBD.
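For reference, the "trailing 2-D COO in AoS format" assumption means the trailing levels share one interleaved coordinate buffer, while the runtime `SparseTensorStorage` keeps a separate coordinate vector per level (SoA). Below is a minimal standalone sketch of that AoS-to-SoA split; the function name and plain `std::vector` containers are illustrative only, not the actual Storage.h code.

```cpp
#include <cstdint>
#include <vector>

// Split a trailing-COO coordinate buffer stored as AoS, e.g.
//   {i0, j0, i1, j1, ...}  (crdLen coordinate tuples of rank cooRank),
// into one SoA vector per level: {i0, i1, ...} and {j0, j1, ...}.
static std::vector<std::vector<uint64_t>>
aosToSoa(const uint64_t *aosCrd, uint64_t crdLen, uint64_t cooRank) {
  std::vector<std::vector<uint64_t>> soa(cooRank);
  for (uint64_t l = 0; l < cooRank; l++) {
    soa[l].resize(crdLen);
    for (uint64_t n = 0; n < crdLen; n++)
      soa[l][n] = aosCrd[n * cooRank + l]; // interleaved tuple -> per-level
  }
  return soa;
}
```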
@llvm/pr-subscribers-mlir-execution-engine @llvm/pr-subscribers-mlir
Author: Aart Bik (aartbik)
Changes
Note that the (dis)assemble operations still make some simplifying assumptions (e.g., trailing 2-D COO in AoS format), but now at least the direct IR and support library paths behave exactly the same. Generalizing the ops is still TBD.
Patch is 37.83 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/71880.diff
8 Files Affected:
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index 460549726356370..3382e293d123746 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -301,8 +301,8 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
uint64_t lvlRank = getLvlRank();
uint64_t valIdx = 0;
// Linearize the address.
- for (uint64_t lvl = 0; lvl < lvlRank; lvl++)
- valIdx = valIdx * getLvlSize(lvl) + lvlCoords[lvl];
+ for (uint64_t l = 0; l < lvlRank; l++)
+ valIdx = valIdx * getLvlSize(l) + lvlCoords[l];
values[valIdx] = val;
return;
}
@@ -472,9 +472,10 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
uint64_t assembledSize(uint64_t parentSz, uint64_t l) const {
if (isCompressedLvl(l))
return positions[l][parentSz];
- if (isSingletonLvl(l))
- return parentSz; // New size is same as the parent.
- // TODO: support levels assignment for loose/2:4?
+ if (isLooseCompressedLvl(l))
+ return positions[l][2 * parentSz - 1];
+ if (isSingletonLvl(l) || is2OutOf4Lvl(l))
+ return parentSz; // new size same as the parent
assert(isDenseLvl(l));
return parentSz * getLvlSize(l);
}
@@ -766,40 +767,59 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
const uint64_t *dim2lvl, const uint64_t *lvl2dim, const intptr_t *lvlBufs)
: SparseTensorStorage(dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes,
dim2lvl, lvl2dim) {
+ // Note that none of the buffers cany be reused because ownership
+ // of the memory passed from clients is not necessarily transferred.
+ // Therefore, all data is copied over into a new SparseTensorStorage.
+ //
+ // TODO: this needs to be generalized to all formats AND
+ // we need a proper audit of e.g. double compressed
+ // levels where some are not filled
+ //
uint64_t trailCOOLen = 0, parentSz = 1, bufIdx = 0;
for (uint64_t l = 0; l < lvlRank; l++) {
- if (!isUniqueLvl(l) && isCompressedLvl(l)) {
- // A `compressed_nu` level marks the start of trailing COO start level.
- // Since the coordinate buffer used for trailing COO are passed in as AoS
- // scheme, and SparseTensorStorage uses a SoA scheme, we can not simply
- // copy the value from the provided buffers.
+ if (!isUniqueLvl(l) && (isCompressedLvl(l) || isLooseCompressedLvl(l))) {
+ // A `(loose)compressed_nu` level marks the start of trailing COO
+ // start level. Since the coordinate buffer used for trailing COO
+ // is passed in as AoS scheme and SparseTensorStorage uses a SoA
+ // scheme, we cannot simply copy the value from the provided buffers.
trailCOOLen = lvlRank - l;
break;
}
- assert(!isSingletonLvl(l) &&
- "Singleton level not following a compressed_nu level");
- if (isCompressedLvl(l)) {
+ if (isCompressedLvl(l) || isLooseCompressedLvl(l)) {
P *posPtr = reinterpret_cast<P *>(lvlBufs[bufIdx++]);
C *crdPtr = reinterpret_cast<C *>(lvlBufs[bufIdx++]);
- // Copies the lvlBuf into the vectors. The buffer can not be simply reused
- // because the memory passed from users is not necessarily allocated on
- // heap.
- positions[l].assign(posPtr, posPtr + parentSz + 1);
- coordinates[l].assign(crdPtr, crdPtr + positions[l][parentSz]);
+ if (!isLooseCompressedLvl(l)) {
+ positions[l].assign(posPtr, posPtr + parentSz + 1);
+ coordinates[l].assign(crdPtr, crdPtr + positions[l][parentSz]);
+ } else {
+ positions[l].assign(posPtr, posPtr + 2 * parentSz);
+ coordinates[l].assign(crdPtr, crdPtr + positions[l][2 * parentSz - 1]);
+ }
+ } else if (isSingletonLvl(l)) {
+ assert(0 && "general singleton not supported yet");
+ } else if (is2OutOf4Lvl(l)) {
+ assert(0 && "2Out4 not supported yet");
} else {
- // TODO: support levels assignment for loose/2:4?
assert(isDenseLvl(l));
}
parentSz = assembledSize(parentSz, l);
}
+ // Handle Aos vs. SoA mismatch for COO.
if (trailCOOLen != 0) {
uint64_t cooStartLvl = lvlRank - trailCOOLen;
- assert(!isUniqueLvl(cooStartLvl) && isCompressedLvl(cooStartLvl));
+ assert(!isUniqueLvl(cooStartLvl) &&
+ (isCompressedLvl(cooStartLvl) || isLooseCompressedLvl(cooStartLvl)));
P *posPtr = reinterpret_cast<P *>(lvlBufs[bufIdx++]);
C *aosCrdPtr = reinterpret_cast<C *>(lvlBufs[bufIdx++]);
- positions[cooStartLvl].assign(posPtr, posPtr + parentSz + 1);
- P crdLen = positions[cooStartLvl][parentSz];
+ P crdLen;
+ if (!isLooseCompressedLvl(cooStartLvl)) {
+ positions[cooStartLvl].assign(posPtr, posPtr + parentSz + 1);
+ crdLen = positions[cooStartLvl][parentSz];
+ } else {
+ positions[cooStartLvl].assign(posPtr, posPtr + 2 * parentSz);
+ crdLen = positions[cooStartLvl][2 * parentSz - 1];
+ }
for (uint64_t l = cooStartLvl; l < lvlRank; l++) {
coordinates[l].resize(crdLen);
for (uint64_t n = 0; n < crdLen; n++) {
@@ -809,6 +829,7 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
parentSz = assembledSize(parentSz, cooStartLvl);
}
+ // Copy the values buffer.
V *valPtr = reinterpret_cast<V *>(lvlBufs[bufIdx]);
values.assign(valPtr, valPtr + parentSz);
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index d5c9ee41215ae97..8e2c2cd6dad7b19 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -163,6 +163,17 @@ Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
return mlir::convertScalarToDtype(builder, loc, value, dstTp, isUnsignedCast);
}
+Value sparse_tensor::genScalarToTensor(OpBuilder &builder, Location loc,
+ Value elem, Type dstTp) {
+ if (auto rtp = dstTp.dyn_cast<RankedTensorType>()) {
+ // Scalars can only be converted to 0-ranked tensors.
+ assert(rtp.getRank() == 0);
+ elem = sparse_tensor::genCast(builder, loc, elem, rtp.getElementType());
+ return builder.create<tensor::FromElementsOp>(loc, rtp, elem);
+ }
+ return sparse_tensor::genCast(builder, loc, elem, dstTp);
+}
+
Value sparse_tensor::genIndexLoad(OpBuilder &builder, Location loc, Value mem,
Value s) {
Value load = builder.create<memref::LoadOp>(loc, mem, s);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index 1f53f3525203c70..d3b0889b71b514c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -142,6 +142,10 @@ class FuncCallOrInlineGenerator {
/// Add type casting between arith and index types when needed.
Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy);
+/// Add conversion from scalar to given type (possibly a 0-rank tensor).
+Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem,
+ Type dstTp);
+
/// Generates a pointer/index load from the sparse storage scheme. Narrower
/// data types need to be zero extended before casting the value into the
/// index type used for looping and indexing.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 08c38394a46343a..888f513be2e4dc7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -435,19 +435,6 @@ static ReassociationIndices getReassociationForFlattening(ShapedType srcTp) {
return reassociation;
}
-/// Generates scalar to tensor cast.
-static Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem,
- Type dstTp) {
- if (auto rtp = dstTp.dyn_cast<RankedTensorType>()) {
- // Scalars can only be converted to 0-ranked tensors.
- if (rtp.getRank() != 0)
- return nullptr;
- elem = genCast(builder, loc, elem, rtp.getElementType());
- return builder.create<tensor::FromElementsOp>(loc, rtp, elem);
- }
- return genCast(builder, loc, elem, dstTp);
-}
-
//===----------------------------------------------------------------------===//
// Codegen rules.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 4fe9c59d8c320a7..e629133171e15dc 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -46,17 +46,6 @@ static std::optional<Type> convertSparseTensorTypes(Type type) {
return std::nullopt;
}
-/// Replaces the `op` with a `CallOp` to the `getFunc()` function reference.
-static func::CallOp replaceOpWithFuncCall(RewriterBase &rewriter, Operation *op,
- StringRef name, TypeRange resultType,
- ValueRange operands,
- EmitCInterface emitCInterface) {
- auto fn = getFunc(op->getParentOfType<ModuleOp>(), name, resultType, operands,
- emitCInterface);
- return rewriter.replaceOpWithNewOp<func::CallOp>(op, resultType, fn,
- operands);
-}
-
/// Generates call to lookup a level-size. N.B., this only generates
/// the raw function call, and therefore (intentionally) does not perform
/// any dim<->lvl conversion or other logic.
@@ -264,11 +253,36 @@ class NewCallParams final {
};
/// Generates a call to obtain the values array.
-static Value genValuesCall(OpBuilder &builder, Location loc, ShapedType tp,
- ValueRange ptr) {
- SmallString<15> name{"sparseValues",
- primaryTypeFunctionSuffix(tp.getElementType())};
- return createFuncCall(builder, loc, name, tp, ptr, EmitCInterface::On)
+static Value genValuesCall(OpBuilder &builder, Location loc,
+ SparseTensorType stt, Value ptr) {
+ auto eltTp = stt.getElementType();
+ auto resTp = MemRefType::get({ShapedType::kDynamic}, eltTp);
+ SmallString<15> name{"sparseValues", primaryTypeFunctionSuffix(eltTp)};
+ return createFuncCall(builder, loc, name, resTp, {ptr}, EmitCInterface::On)
+ .getResult(0);
+}
+
+/// Generates a call to obtain the positions array.
+static Value genPositionsCall(OpBuilder &builder, Location loc,
+ SparseTensorType stt, Value ptr, Level l) {
+ Type posTp = stt.getPosType();
+ auto resTp = MemRefType::get({ShapedType::kDynamic}, posTp);
+ Value lvl = constantIndex(builder, loc, l);
+ SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
+ return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
+ EmitCInterface::On)
+ .getResult(0);
+}
+
+/// Generates a call to obtain the coordindates array.
+static Value genCoordinatesCall(OpBuilder &builder, Location loc,
+ SparseTensorType stt, Value ptr, Level l) {
+ Type crdTp = stt.getCrdType();
+ auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
+ Value lvl = constantIndex(builder, loc, l);
+ SmallString<19> name{"sparseCoordinates", overheadTypeFunctionSuffix(crdTp)};
+ return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
+ EmitCInterface::On)
.getResult(0);
}
@@ -391,7 +405,7 @@ class SparseTensorAllocConverter
SmallVector<Value> dimSizes;
dimSizes.reserve(dimRank);
unsigned operandCtr = 0;
- for (Dimension d = 0; d < dimRank; ++d) {
+ for (Dimension d = 0; d < dimRank; d++) {
dimSizes.push_back(
stt.isDynamicDim(d)
? adaptor.getOperands()[operandCtr++]
@@ -423,7 +437,7 @@ class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
dimSizes.reserve(dimRank);
auto shape = op.getType().getShape();
unsigned operandCtr = 0;
- for (Dimension d = 0; d < dimRank; ++d) {
+ for (Dimension d = 0; d < dimRank; d++) {
dimSizes.push_back(stt.isDynamicDim(d)
? adaptor.getOperands()[operandCtr++]
: constantIndex(rewriter, loc, shape[d]));
@@ -487,12 +501,10 @@ class SparseTensorToPositionsConverter
LogicalResult
matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Type resTp = op.getType();
- Type posTp = cast<ShapedType>(resTp).getElementType();
- SmallString<17> name{"sparsePositions", overheadTypeFunctionSuffix(posTp)};
- Value lvl = constantIndex(rewriter, op->getLoc(), op.getLevel());
- replaceOpWithFuncCall(rewriter, op, name, resTp, {adaptor.getTensor(), lvl},
- EmitCInterface::On);
+ auto stt = getSparseTensorType(op.getTensor());
+ auto poss = genPositionsCall(rewriter, op.getLoc(), stt,
+ adaptor.getTensor(), op.getLevel());
+ rewriter.replaceOp(op, poss);
return success();
}
};
@@ -505,29 +517,14 @@ class SparseTensorToCoordinatesConverter
LogicalResult
matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // TODO: use `SparseTensorType::getCrdType` instead.
- Type resType = op.getType();
- const Type crdTp = cast<ShapedType>(resType).getElementType();
- SmallString<19> name{"sparseCoordinates",
- overheadTypeFunctionSuffix(crdTp)};
- Location loc = op->getLoc();
- Value lvl = constantIndex(rewriter, loc, op.getLevel());
-
- // The function returns a MemRef without a layout.
- MemRefType callRetType = get1DMemRefType(crdTp, false);
- SmallVector<Value> operands{adaptor.getTensor(), lvl};
- auto fn = getFunc(op->getParentOfType<ModuleOp>(), name, callRetType,
- operands, EmitCInterface::On);
- Value callRet =
- rewriter.create<func::CallOp>(loc, callRetType, fn, operands)
- .getResult(0);
-
+ auto stt = getSparseTensorType(op.getTensor());
+ auto crds = genCoordinatesCall(rewriter, op.getLoc(), stt,
+ adaptor.getTensor(), op.getLevel());
// Cast the MemRef type to the type expected by the users, though these
// two types should be compatible at runtime.
- if (resType != callRetType)
- callRet = rewriter.create<memref::CastOp>(loc, resType, callRet);
- rewriter.replaceOp(op, callRet);
-
+ if (op.getType() != crds.getType())
+ crds = rewriter.create<memref::CastOp>(op.getLoc(), op.getType(), crds);
+ rewriter.replaceOp(op, crds);
return success();
}
};
@@ -539,9 +536,9 @@ class SparseTensorToValuesConverter : public OpConversionPattern<ToValuesOp> {
LogicalResult
matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto resType = cast<ShapedType>(op.getType());
- rewriter.replaceOp(op, genValuesCall(rewriter, op.getLoc(), resType,
- adaptor.getOperands()));
+ auto stt = getSparseTensorType(op.getTensor());
+ auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
+ rewriter.replaceOp(op, vals);
return success();
}
};
@@ -554,13 +551,11 @@ class SparseNumberOfEntriesConverter
LogicalResult
matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Location loc = op.getLoc();
// Query values array size for the actually stored values size.
- Type eltType = cast<ShapedType>(op.getTensor().getType()).getElementType();
- auto resTp = MemRefType::get({ShapedType::kDynamic}, eltType);
- Value values = genValuesCall(rewriter, loc, resTp, adaptor.getOperands());
- rewriter.replaceOpWithNewOp<memref::DimOp>(op, values,
- constantIndex(rewriter, loc, 0));
+ auto stt = getSparseTensorType(op.getTensor());
+ auto vals = genValuesCall(rewriter, op.getLoc(), stt, adaptor.getTensor());
+ auto zero = constantIndex(rewriter, op.getLoc(), 0);
+ rewriter.replaceOpWithNewOp<memref::DimOp>(op, vals, zero);
return success();
}
};
@@ -701,7 +696,7 @@ class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
}
};
-/// Sparse conversion rule for the sparse_tensor.pack operator.
+/// Sparse conversion rule for the sparse_tensor.assemble operator.
class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
public:
using OpConversionPattern::OpConversionPattern;
@@ -710,9 +705,12 @@ class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
ConversionPatternRewriter &rewriter) const override {
const Location loc = op->getLoc();
const auto dstTp = getSparseTensorType(op.getResult());
- // AssembleOps always returns a static shaped tensor result.
assert(dstTp.hasStaticDimShape());
SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, dstTp);
+ // Use a library method to transfer the external buffers from
+ // clients to the internal SparseTensorStorage. Since we cannot
+ // assume clients transfer ownership of the buffers, this method
+ // will copy all data over into a new SparseTensorStorage.
Value dst =
NewCallParams(rewriter, loc)
.genBuffers(dstTp.withoutDimToLvl(), dimSizes)
@@ -724,6 +722,115 @@ class SparseTensorAssembleConverter : public OpConversionPattern<AssembleOp> {
}
};
+/// Sparse conversion rule for the sparse_tensor.disassemble operator.
+class SparseTensorDisassembleConverter
+ : public OpConversionPattern<DisassembleOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // We simply expose the buffers to the external client. This
+ // assumes the client only reads the buffers (usually copying it
+ // to the external data structures, such as numpy arrays).
+ Location loc = op->getLoc();
+ auto stt = getSparseTensorType(op.getTensor());
+ SmallVector<Value> retVal;
+ SmallVector<Value> retLen;
+ // Get the values buffer first.
+ auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
+ auto valLenTp = op.getValLen().getType();
+ auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
+ retVal.push_back(vals);
+ retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
+ // Then get the positions and coordinates buffers.
+ const Level lvlRank = stt.getLvlRank();
+ Level trailCOOLen = 0;
+ for (Level l = 0; l < lvlRank; l++) {
+ if (!stt.isUniqueLvl(l) &&
+ (stt.isCompressedLvl(l) || stt.isLooseCompressedLvl(l))) {
+ // A `(loose)compressed_nu` level marks the start of trailing COO
+ // start level. Since the target coordinate buffer used for trailing
+ // COO is passed in as AoS scheme and SparseTensorStorage uses a SoA
+ // scheme, we cannot simply use the internal buffers.
+ trailCOOLen = lvlRank - l;
+ break;
+ }
+ if (stt.isWithPos(l)) {
+ auto poss =
+ genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
+ auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
+ auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+ retVal.push_back(poss);
+ retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
+ }
+ if (stt.isWithCrd(l)) {
+ auto crds =
+ genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
+ auto crdLen = linalg::createOrFoldDimOp(rewriter, lo...
[truncated]
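The size bookkeeping that both paths now share follows the `assembledSize` logic in the Storage.h hunk above: compressed levels read the last position entry, loose-compressed levels keep two position entries per parent (so the total sits at index `2 * parentSz - 1`), and singleton/2:4 levels keep the parent count. The standalone sketch below mirrors that logic; the `LvlKind` enum and free function stand in for the real level-type queries and are illustrative only.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

enum class LvlKind { Dense, Compressed, LooseCompressed, Singleton, TwoOutOfFour };

// Given the number of "parent" entries above level l, return how many
// entries the assembled storage holds at level l.
static uint64_t assembledSize(LvlKind kind, uint64_t parentSz, uint64_t lvlSize,
                              const std::vector<uint64_t> &positions) {
  switch (kind) {
  case LvlKind::Compressed:
    return positions[parentSz];          // last position = total stored entries
  case LvlKind::LooseCompressed:
    return positions[2 * parentSz - 1];  // hi bound of the last parent region
  case LvlKind::Singleton:
  case LvlKind::TwoOutOfFour:
    return parentSz;                     // same as the parent
  case LvlKind::Dense:
    return parentSz * lvlSize;           // dense expansion
  }
  assert(false && "unknown level kind");
  return 0;
}
```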
…ath (llvm#71880) Note that the (dis)assemble operations still make some simplifying assumptions (e.g., trailing 2-D COO in AoS format), but now at least the direct IR and support library paths behave exactly the same. Generalizing the ops is still TBD.