[mlir][sparse] implement sparse_tensor.reorder_coo #68916
Merged
Conversation
@llvm/pr-subscribers-mlir-execution-engine @llvm/pr-subscribers-mlir-sparse

Author: Peiming Liu (PeimingLiu)

Changes

As a side effect of the change, this also unifies the ConvertOp implementation between the lib and codegen paths.

Patch is 122.33 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/68916.diff

14 Files Affected:
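Before the diff, a quick usage sketch of the new op. The encoding aliases, shapes, and element type here are illustrative assumptions rather than excerpts from the patch, and sparse encoding syntax varies across MLIR versions:

```mlir
// Hypothetical COO encodings; only the ordering properties differ.
#COO_unordered = #sparse_tensor.encoding<{
  map = (i, j) -> (i : compressed(nonunique, nonordered), j : singleton(nonordered))
}>
#COO_ordered = #sparse_tensor.encoding<{
  map = (i, j) -> (i : compressed(nonunique), j : singleton)
}>

// Sort the entries of an unordered COO into lexicographic coordinate order.
%ordered = sparse_tensor.reorder_coo quick_sort %unordered
    : tensor<?x?xf64, #COO_unordered> to tensor<?x?xf64, #COO_ordered>
```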
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 2920ef79f461c6a..ca9555248130f08 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -151,6 +151,8 @@ enum class Action : uint32_t {
kToCOO = 5,
kToIterator = 6,
kPack = 7,
+ // Sort an unordered COO in place.
+ kSortCOOInPlace = 8,
};
/// This enum defines all the sparse representations supportable by
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index afbabb97eb71fc5..9016634fa3be8dd 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -200,10 +200,6 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
// Whether the convert can be done by a single step (either a sort or a foreach),
// or it would require a tmp buffer (sort, then foreach).
bool directConvertable();
-
- // Whether the convert is actually a sort coo
- // TODO: The method will be removed when sort_coo operation is introduced.
- bool isSortCOOConvert();
}];
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
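(The ODS definition of the new op lands in a part of the patch not shown here. Reconstructed from its uses in the lowering code below, namely `getInputCoo`, `getResultCoo`, and `getAlgorithm`, here is a plausible sketch, not the actual patch content:)

```tablegen
def SparseTensor_ReorderCOOOp : SparseTensor_Op<"reorder_coo", [Pure]>,
    Arguments<(ins AnySparseTensor:$input_coo,
                   SparseTensorSortKindAttr:$algorithm)>,
    Results<(outs AnySparseTensor:$result_coo)> {
  let summary = "Reorders the input unordered COO into ordered COO";
  let assemblyFormat = "$algorithm $input_coo attr-dict"
                       " `:` type($input_coo) `to` type($result_coo)";
  let hasVerifier = 1;
}
```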
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index 303a41bc471d5d9..751ee8d1b17dc37 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -349,6 +349,8 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
~SparseTensorStorage() final = default;
+ void sortInPlace();
+
/// Partially specialize these getter methods based on template types.
void getPositions(std::vector<P> **out, uint64_t lvl) final {
assert(out && "Received nullptr for out parameter");
@@ -374,6 +376,24 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
/// Partially specialize lexicographical insertions based on template types.
void lexInsert(const uint64_t *lvlCoords, V val) final {
assert(lvlCoords && "Received nullptr for level-coordinates");
+ // TODO: get rid of this! canonicalize all-dense "sparse" array into dense
+ // tensors.
+ bool allDense = true;
+ for (DimLevelType lt : getLvlTypes()) {
+ if (!isDenseDLT(lt)) {
+ allDense = false;
+ break;
+ }
+ }
+ if (allDense) {
+ uint64_t lvlRank = getLvlRank();
+ uint64_t valIdx = 0;
+ // Linearize the address
+ for (size_t lvl = 0; lvl < lvlRank; lvl++)
+ valIdx = valIdx * getLvlSize(lvl) + lvlCoords[lvl];
+ values[valIdx] = val;
+ return;
+ }
// First, wrap up pending insertion path.
uint64_t diffLvl = 0;
uint64_t full = 0;
@@ -956,6 +976,61 @@ SparseTensorStorage<P, C, V> *SparseTensorStorage<P, C, V>::packFromLvlBuffers(
return tensor;
}
+template <typename P, typename C, typename V>
+void SparseTensorStorage<P, C, V>::sortInPlace() {
+ uint64_t nnz = values.size();
+#ifndef NDEBUG
+ for (uint64_t l = 0; l < getLvlRank(); l++)
+ assert(nnz == coordinates[l].size());
+#endif
+
+ // In-place permutation.
+ auto applyPerm = [this](std::vector<uint64_t> &perm) {
+ size_t length = perm.size();
+ size_t lvlRank = getLvlRank();
+ // Cache for the current level coordinates.
+ std::vector<P> lvlCrds(lvlRank);
+ for (size_t i = 0; i < length; i++) {
+ size_t current = i;
+ if (i != perm[current]) {
+ for (size_t l = 0; l < lvlRank; l++)
+ lvlCrds[l] = coordinates[l][i];
+ V val = values[i];
+ // Deals with a permutation cycle.
+ while (i != perm[current]) {
+ size_t next = perm[current];
+ // Swaps the level coordinates and value.
+ for (size_t l = 0; l < lvlRank; l++)
+ coordinates[l][current] = coordinates[l][next];
+ values[current] = values[next];
+ perm[current] = current;
+ current = next;
+ }
+ for (size_t l = 0; l < lvlRank; l++)
+ coordinates[l][current] = lvlCrds[l];
+ values[current] = val;
+ perm[current] = current;
+ }
+ }
+ };
+
+ std::vector<uint64_t> sortedIdx(nnz, 0);
+ for (uint64_t i = 0; i < nnz; i++)
+ sortedIdx[i] = i;
+
+ std::sort(sortedIdx.begin(), sortedIdx.end(),
+ [this](uint64_t lhs, uint64_t rhs) {
+ for (uint64_t l = 0; l < getLvlRank(); l++) {
+ if (coordinates[l][lhs] == coordinates[l][rhs])
+ continue;
+ return coordinates[l][lhs] < coordinates[l][rhs];
+ }
+ assert(false && "duplicate coordinates");
+ return false;
+ });
+
+ applyPerm(sortedIdx);
+}
+
template <typename P, typename C, typename V>
SparseTensorStorage<P, C, V>::SparseTensorStorage(
uint64_t dimRank, const uint64_t *dimSizes, uint64_t lvlRank,
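The `sortInPlace` routine above first sorts an index permutation, then applies it by walking permutation cycles, so no reordered copies of the coordinate and value arrays are allocated. A standalone sketch of the cycle-following step on a single payload array (illustrative code, not the library implementation):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// In-place "pull" permutation: afterwards, data[i] holds the element
// previously stored at data[perm[i]]. Each cycle needs only O(1) extra
// space; perm is taken by value and consumed as a visited marker.
void applyPermInPlace(std::vector<double> &data, std::vector<uint64_t> perm) {
  assert(data.size() == perm.size());
  for (uint64_t i = 0; i < perm.size(); i++) {
    if (perm[i] == i)
      continue;                   // Fixed point, or cycle already resolved.
    double saved = data[i];       // Remember the first element of the cycle.
    uint64_t current = i;
    while (perm[current] != i) {  // Walk the cycle until it closes.
      uint64_t next = perm[current];
      data[current] = data[next]; // Pull the next element into place.
      perm[current] = current;    // Mark this slot as resolved.
      current = next;
    }
    data[current] = saved;        // Close the cycle.
    perm[current] = current;
  }
}
```

In the patch, the permutation comes from `std::sort` over an index vector, and the cycle walk moves all per-level coordinate arrays and the value array in lockstep.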
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index ef9d4fea68628b9..61522fb0dcd24b5 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1060,20 +1060,12 @@ LogicalResult ConvertOp::verify() {
}
OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
- Type dstType = getType();
- // Fold trivial dense-to-dense convert and leave trivial sparse-to-sparse
- // convert for codegen to remove. This is because we use trivial
- // sparse-to-sparse convert to tell bufferization that the sparse codegen
- // will expand the tensor buffer into sparse tensor storage.
- if (!getSparseTensorEncoding(dstType) && dstType == getSource().getType())
+ if (getType() == getSource().getType())
return getSource();
return {};
}
bool ConvertOp::directConvertable() {
- if (isSortCOOConvert())
- return false;
-
SparseTensorType srcStt = getSparseTensorType(getSource());
SparseTensorType dstStt = getSparseTensorType(getDest());
@@ -1099,15 +1091,6 @@ bool ConvertOp::directConvertable() {
return false;
}
-bool ConvertOp::isSortCOOConvert() {
- // TODO: we should instead use a different sort_coo operation to handle
- // the conversion between COOs (but with different ordering).
- return isUniqueCOOType(getSource().getType()) &&
- isUniqueCOOType(getDest().getType()) &&
- !getSparseTensorType(getSource()).isAllOrdered() &&
- getSparseTensorType(getDest()).isAllOrdered();
-}
-
LogicalResult ToPositionsOp::verify() {
auto e = getSparseTensorEncoding(getTensor().getType());
if (failed(lvlIsInBounds(getLevel(), getTensor())))
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 78f5562b392a682..378dd9128839d7f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -680,31 +680,26 @@ class SparseDimOpConverter : public OpConversionPattern<tensor::DimOp> {
};
// TODO: use a new SortCOO operation here instead of reusing convert op.
-struct SparseSortCOOConverter : public OpConversionPattern<ConvertOp> {
+struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(ConvertOp op, ConvertOpAdaptor adaptor,
+ matchAndRewrite(ReorderCOOOp op, ReorderCOOOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // Direct conversion should have already been lowered.
- if (!op.isSortCOOConvert())
- return failure();
-
Location loc = op.getLoc();
MLIRContext *ctx = op.getContext();
- SparseTensorType srcStt = getSparseTensorType(op.getSource());
- SparseTensorType dstStt = getSparseTensorType(op.getDest());
+ SparseTensorType srcStt = getSparseTensorType(op.getInputCoo());
+ SparseTensorType dstStt = getSparseTensorType(op.getResultCoo());
- // TODO: This should be verification rules for sort_coo operation.
+ // Should have been verified.
assert(dstStt.isAllOrdered() && !srcStt.isAllOrdered() &&
isUniqueCOOType(srcStt.getRankedTensorType()) &&
isUniqueCOOType(dstStt.getRankedTensorType()));
-
assert(dstStt.hasSameDimToLvl(srcStt));
// We don't need a mutable descriptor here as we perform sorting in-place.
- auto nnz = genValMemSize(rewriter, op.getLoc(), adaptor.getSource());
- auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
+ auto nnz = genValMemSize(rewriter, op.getLoc(), adaptor.getInputCoo());
+ auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo());
auto crd = desc.getAOSMemRef();
auto val = desc.getValMemRef();
@@ -715,12 +710,11 @@ struct SparseSortCOOConverter : public OpConversionPattern<ConvertOp> {
auto id = AffineMap::getMultiDimIdentityMap(srcStt.getLvlRank(), ctx);
rewriter.create<SortOp>(loc, nnz, crd, ValueRange{val}, id,
- rewriter.getIndexAttr(0),
- SparseTensorSortKind::HybridQuickSort);
+ rewriter.getIndexAttr(0), op.getAlgorithm());
// Since we do in-place sorting, the destination tensor will have the same set
// of memrefs as the source tensor.
- rewriter.replaceOp(op, adaptor.getSource());
+ rewriter.replaceOp(op, adaptor.getInputCoo());
return success();
}
};
@@ -1147,9 +1141,6 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
LogicalResult
matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- if (op.isSortCOOConvert())
- return failure();
-
SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
SparseTensorEncodingAttr encSrc =
getSparseTensorEncoding(op.getSource().getType());
@@ -1603,7 +1594,7 @@ void mlir::populateSparseTensorCodegenPatterns(
SparseCastConverter, SparseExtractSliceConverter,
SparseTensorLoadConverter, SparseExpandConverter,
SparseCompressConverter, SparseInsertConverter,
- SparseSortCOOConverter,
+ SparseReorderCOOConverter,
SparseSliceGetterOpConverter<ToSliceOffsetOp,
StorageSpecifierKind::DimOffset>,
SparseSliceGetterOpConverter<ToSliceStrideOp,
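The net effect of the codegen-path converter above is that the COO tensor's coordinate and value buffers are sorted in place by a single `sparse_tensor.sort`, and the result simply reuses the input's storage. The emitted IR looks roughly like the following; value names, element type, and the exact sort assembly are illustrative assumptions:

```mlir
// %nnz (index): number of stored entries, taken from the value memref size.
// %crd (memref<?xindex>): AoS coordinate storage of the COO tensor.
// %val (memref<?xf64>): value storage.
sparse_tensor.sort quick_sort %nnz, %crd jointly %val
    { perm_map = affine_map<(i, j) -> (i, j)>, ny = 0 : index }
    : memref<?xindex> jointly memref<?xf64>
// The reorder_coo result is then replaced by the (now sorted) input buffers.
```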
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index d2d7b46ab834e71..5d9ee6906749ec1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -299,76 +299,6 @@ static void genDelCOOCall(OpBuilder &builder, Location loc, Type elemTp,
createFuncCall(builder, loc, name, {}, coo, EmitCInterface::Off);
}
-/// Generates a call to release/delete a `SparseTensorIterator`.
-static void genDelIteratorCall(OpBuilder &builder, Location loc, Type elemTp,
- Value iter) {
- SmallString<26> name{"delSparseTensorIterator",
- primaryTypeFunctionSuffix(elemTp)};
- createFuncCall(builder, loc, name, {}, iter, EmitCInterface::Off);
-}
-
-/// Generates a call that adds one element to a coordinate scheme.
-/// In particular, this generates code like the following:
-/// val = a[i1,..,ik];
-/// if val != 0
-/// t->add(&val, [i1,..,ik], [p1,..,pk]);
-static void genAddEltCall(OpBuilder &builder, Location loc, Type eltType,
- Value lvlCOO, Value valPtr, Value dimCoords,
- Value dimToLvl) {
- SmallString<9> name{"addElt", primaryTypeFunctionSuffix(eltType)};
- SmallVector<Value, 4> params{lvlCOO, valPtr, dimCoords, dimToLvl};
- Type pTp = getOpaquePointerType(builder);
- createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On);
-}
-
-/// Generates a call to `iter->getNext()`. If there is a next element,
-/// then it is copied into the out-parameters `coords` and `elemPtr`,
-/// and the return value is true. If there isn't a next element, then
-/// the return value is false.
-///
-/// The `coords` argument uses the same coordinate-space as the `iter`
-/// (which can be either dim- or lvl-coords, depending on context).
-static Value genGetNextCall(OpBuilder &builder, Location loc, Value iter,
- Value coords, Value elemPtr) {
- Type elemTp = cast<ShapedType>(elemPtr.getType()).getElementType();
- SmallString<10> name{"getNext", primaryTypeFunctionSuffix(elemTp)};
- SmallVector<Value, 3> params{iter, coords, elemPtr};
- Type i1 = builder.getI1Type();
- return createFuncCall(builder, loc, name, i1, params, EmitCInterface::On)
- .getResult(0);
-}
-
-/// Loads the value stored in `elemPtr`, and stores it at the coordinates
-/// `cvs` into a dense tensor created by `allocDenseTensor`.
-static void insertScalarIntoDenseTensor(OpBuilder &builder, Location loc,
- Value elemPtr, Value tensor,
- ValueRange cvs) {
- Value elemV = builder.create<memref::LoadOp>(loc, elemPtr);
- builder.create<memref::StoreOp>(loc, elemV, tensor, cvs);
-}
-
-/// Determine if the runtime library supports direct conversion to the
-/// given target `dimTypes`.
-static bool canUseDirectConversion(ArrayRef<DimLevelType> dimTypes) {
- bool alreadyCompressed = false;
- for (const auto dlt : dimTypes) {
- if (isCompressedDLT(dlt)) {
- if (alreadyCompressed)
- return false; // Multiple compressed dimensions not yet supported.
- alreadyCompressed = true;
- } else if (isDenseDLT(dlt)) {
- if (alreadyCompressed)
- return false; // Dense after Compressed not yet supported.
- } else if (isSingletonDLT(dlt)) {
- // Direct conversion doesn't have any particular problems with
- // singleton after compressed.
- } else { // TODO: investigate
- return false;
- }
- }
- return true;
-}
-
//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//
@@ -540,179 +470,27 @@ class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
};
/// Sparse conversion rule for the convert operator.
-class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
+class SparseTensorReorderCOOConverter
+ : public OpConversionPattern<ReorderCOOOp> {
public:
using OpConversionPattern::OpConversionPattern;
- SparseTensorConvertConverter(MLIRContext *context,
- SparseTensorConversionOptions o)
- : OpConversionPattern<ConvertOp>(context), options(o) {}
- SparseTensorConvertConverter(TypeConverter &typeConv, MLIRContext *context,
- SparseTensorConversionOptions o)
- : OpConversionPattern<ConvertOp>(typeConv, context), options(o) {}
LogicalResult
- matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
+ matchAndRewrite(ReorderCOOOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
const Location loc = op->getLoc();
- const auto srcTp = getSparseTensorType(op.getSource());
+ const auto srcTp = getSparseTensorType(op.getInputCoo());
const auto dstTp = getSparseTensorType(op);
- if (!srcTp.hasEncoding() && !dstTp.hasEncoding())
- return failure();
- const Dimension dimRank = srcTp.getDimRank();
- const Type elemTp = srcTp.getElementType();
- const Value src = adaptor.getOperands()[0];
- if (srcTp.hasEncoding() && dstTp.hasEncoding()) {
- const auto srcEnc = srcTp.getEncoding();
- const auto dstEnc = dstTp.getEncoding();
- // This is a sparse => sparse conversion, which is handled as follows:
- // t = src->toCOO(); ; src to COO in dst order
- // dst = newSparseTensor(t)
- // Using the coordinate scheme as an intermediate does not always
- // yield the fastest conversion but avoids the need for a full
- // O(N^2) conversion matrix.
- if (dstEnc == srcEnc) {
- rewriter.replaceOp(op, adaptor.getOperands()); // hidden nop cast
- return success();
- }
- NewCallParams params(rewriter, loc);
- SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src);
- bool useDirectConversion;
- switch (options.sparseToSparseStrategy) {
- case SparseToSparseConversionStrategy::kViaCOO:
- useDirectConversion = false;
- break;
- case SparseToSparseConversionStrategy::kDirect:
- useDirectConversion = true;
- assert(canUseDirectConversion(dstEnc.getLvlTypes()) &&
- "Unsupported target for direct sparse-to-sparse conversion");
- break;
- case SparseToSparseConversionStrategy::kAuto:
- useDirectConversion = canUseDirectConversion(dstEnc.getLvlTypes());
- break;
- }
- if (useDirectConversion) {
- rewriter.replaceOp(
- op, params.genBuffers(srcTp.withEncoding(dstEnc), dimSizes)
- .genNewCall(Action::kSparseToSparse, src));
- } else { // use via-COO conversion.
- // Set up encoding with right mix of src and dst so that the two
- // method calls can share most parameters, while still providing
- // the correct sparsity information to either of them.
- const auto mixedEnc =
- dstEnc.withBitWidths(srcEnc.getPosWidth(), srcEnc.getCrdWidth());
- // TODO: This is the only place where `kToCOO` (or `kToIterator`)
- // is called with a non-identity permutation. Is there any clean
- // way to push the permutation over to the `kFromCOO` side instead?
- Value coo = params.genBuffers(srcTp.withEncoding(mixedEnc), dimSizes)
- .genNewCall(Action::kToCOO, src);
- Value dst = params.setTemplateTypes(srcTp.withEncoding(dstEnc))
- .genNewCall(Action::kFromCOO, coo);
- genDelCOOCall(rewriter, loc, elemTp, coo);
- rewriter.replaceOp(op, dst);
- }
- return success();
- }
- if (srcTp.hasEncoding() && !dstTp.hasEncoding()) {
- const auto srcEnc = srcTp.getEncoding();
- // This is sparse => dense conversion, which is handled as follows:
- // dst = new Tensor(0);
- // iter = new SparseTensorIterator(src);
- // while (elem = iter->getNext()) {
- // dst[elem.coords] = elem.value;
- // }
- // delete iter;
- //
- // Fabricate a no-permutation encoding for NewCallParams
- // The position/coordinate types must be those of `src`.
- // The dimLevelTypes aren't actually used by Action::kToIterator.
- const auto dstEnc = SparseTensorEncodingAttr::get(
- op->getContext(),
- SmallVector<DimLevelType>(dimRank, DimLevelType::Dense), AffineMap(),
- AffineMap(), srcEnc.getPosWidth(), srcEnc.getCrdWidth());
- SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, srcTp, src);
- Value iter = NewCallParams(rewriter, loc)
- .genBuffers(dstTp.withEncoding(dstEnc), dimSizes)
- .genNewCall(Action::kToIterator, src);
- const Type iTp = rewriter.getIndexType();
- Value dimCoords = genAlloca(rewriter, loc, dimRank, iTp);
- Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
- // TODO: Dense buffers should be allocated/deallocated via the callback
- // in BufferizationOptions.
- Value dst = allocDenseTensor(rewriter, loc, dstTp, dimSizes);
- const SmallVector<Value> noArgs;
- const SmallVector<Type> noTypes;
- auto whileOp = rewriter.create<scf::WhileOp>(loc, noTypes, noArgs);
- Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, noTypes);
- rewriter.setInsertionPointToEnd(before);
- Value cond = genGetNextCall(rewriter, loc, iter, dimCoords, elemPtr);
- rewriter.create<scf::ConditionOp>(loc, cond, before...
[truncated]
yinying-lisa-li approved these changes on Oct 12, 2023
aartbik reviewed on Oct 12, 2023