Skip to content

Commit f2696e4

Browse files
committed
[mlir][sparse] Cleaning up some usage of SparseTensorType
This is a followup to D147192. Reviewed By: aartbik, Peiming Differential Revision: https://reviews.llvm.org/D147196
1 parent 498aa53 commit f2696e4

File tree

1 file changed

+12
-16
lines changed

1 file changed

+12
-16
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -356,28 +356,22 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
356356
PatternRewriter &rewriter) const override {
357357
Location loc = op.getLoc();
358358
Value srcTensor = op.getSrc();
359-
auto srcTp = getRankedTensorType(srcTensor);
360-
auto dstTp = getRankedTensorType(op.getResult());
361-
362-
SparseTensorType srcStt(srcTp);
363-
SparseTensorType dstStt(dstTp);
364-
365-
const auto encSrc = srcStt.getEncoding();
366-
if (!srcStt.hasEncoding() || !dstStt.hasEncoding()) {
359+
const auto srcTp = getSparseTensorType(srcTensor);
360+
const auto dstTp = getSparseTensorType(op.getResult());
361+
if (!srcTp.hasEncoding() || !dstTp.hasEncoding())
367362
return failure();
368-
}
369363

370364
// Generate code to represent the static dimension constants or compute
371365
// the dynamic dimension values.
372366
SmallVector<Value> srcSizes;
373367
sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
374368
SmallVector<Value> dstSizes;
375369
SmallVector<Value> dstDynSizes;
376-
if (dstTp.hasStaticShape()) {
377-
for (auto d : dstTp.getShape())
370+
if (dstTp.hasStaticDimShape()) {
371+
for (Dimension d : dstTp.getDimShape())
378372
dstSizes.push_back(constantIndex(rewriter, loc, d));
379373
} else {
380-
ArrayRef<int64_t> dstShape = dstTp.getShape();
374+
ArrayRef<DynSize> dstShape = dstTp.getDimShape();
381375
genReshapeDstShape(loc, rewriter, dstSizes, srcSizes, dstShape,
382376
op.getReassociationIndices());
383377
for (auto [idx, shape] : llvm::enumerate(dstShape)) {
@@ -389,8 +383,8 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
389383
// Only need a unordered COO buffer if input and output are not sorted
390384
// in the same way.
391385
Type bufferTp =
392-
srcStt.isAllOrdered() && srcStt.isIdentity() && dstStt.isIdentity()
393-
? dstTp
386+
srcTp.isAllOrdered() && srcTp.isIdentity() && dstTp.isIdentity()
387+
? dstTp.getRankedTensorType()
394388
: getUnorderedCOOFromType(dstTp);
395389

396390
Value buffer =
@@ -406,11 +400,12 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
406400
// followed by an optional
407401
// %t = sparse_tensor.cast %tmp
408402
// depending on whether the input/output are sorted in the same way.
403+
const auto encSrc = srcTp.getEncoding();
409404
ForeachOp foreachOp = rewriter.create<ForeachOp>(
410405
loc, srcTensor, buffer,
411406
[&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
412407
ValueRange reduc) {
413-
const Dimension dimRank = srcTp.getRank();
408+
const Dimension dimRank = srcTp.getDimRank();
414409
SmallVector<Value> srcDcvs;
415410
srcDcvs.reserve(dimRank);
416411
for (Dimension d = 0; d < dimRank; d++) {
@@ -427,7 +422,8 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
427422

428423
Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
429424
if (bufferTp != dstTp) {
430-
Value converted = rewriter.create<ConvertOp>(loc, dstTp, t).getResult();
425+
auto dstRTT = dstTp.getRankedTensorType();
426+
Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
431427
rewriter.create<DeallocTensorOp>(loc, t);
432428
t = converted;
433429
}

0 commit comments

Comments
 (0)