Skip to content

Commit 1e0966c

Browse files
committed
[mlir][sparse] add util for ToCoordinatesBuffer for COO AoS
Reviewed By: Peiming Differential Revision: https://reviews.llvm.org/D150382
1 parent 677f7cc commit 1e0966c

File tree

3 files changed

+12
-3
lines changed

3 files changed

+12
-3
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -679,6 +679,14 @@ Value sparse_tensor::genToCoordinates(OpBuilder &builder, Location loc,
679679
builder.getIndexAttr(lvl));
680680
}
681681

682+
Value sparse_tensor::genToCoordinatesBuffer(OpBuilder &builder, Location loc,
683+
Value tensor) {
684+
const auto srcTp = getSparseTensorType(tensor);
685+
const Type crdTp = srcTp.getEncoding().getCrdType();
686+
const Type memTp = get1DMemRefType(crdTp, /*withLayout=*/false);
687+
return builder.create<ToCoordinatesBufferOp>(loc, memTp, tensor);
688+
}
689+
682690
Value sparse_tensor::genToValues(OpBuilder &builder, Location loc,
683691
Value tensor) {
684692
RankedTensorType srcTp = getRankedTensorType(tensor);

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,9 @@ Value genToPositions(OpBuilder &builder, Location loc, Value tensor, Level lvl);
364364
Value genToCoordinates(OpBuilder &builder, Location loc, Value tensor,
365365
Level lvl, Level cooStart);
366366

367+
/// Infers the result type and generates `ToCoordinatesBufferOp`.
368+
Value genToCoordinatesBuffer(OpBuilder &builder, Location loc, Value tensor);
369+
367370
/// Infers the result type and generates `ToValuesOp`.
368371
Value genToValues(OpBuilder &builder, Location loc, Value tensor);
369372

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -895,9 +895,7 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
895895
// coordinates for the storage ordering of the dst tensor. Use SortCoo
896896
// if the COO tensor has the same ordering as the dst tensor.
897897
if (dimRank > 1 && srcTp.hasSameDimToLvlMap(dstTp)) {
898-
MemRefType coordsTp =
899-
get1DMemRefType(encSrc.getCrdType(), /*withLayout=*/false);
900-
Value xs = rewriter.create<ToCoordinatesBufferOp>(loc, coordsTp, src);
898+
Value xs = genToCoordinatesBuffer(rewriter, loc, src);
901899
rewriter.create<SortCooOp>(
902900
loc, nnz, xs, ValueRange{y}, rewriter.getIndexAttr(dimRank),
903901
rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);

0 commit comments

Comments
 (0)