@@ -356,28 +356,22 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
356
356
PatternRewriter &rewriter) const override {
357
357
Location loc = op.getLoc ();
358
358
Value srcTensor = op.getSrc ();
359
- auto srcTp = getRankedTensorType (srcTensor);
360
- auto dstTp = getRankedTensorType (op.getResult ());
361
-
362
- SparseTensorType srcStt (srcTp);
363
- SparseTensorType dstStt (dstTp);
364
-
365
- const auto encSrc = srcStt.getEncoding ();
366
- if (!srcStt.hasEncoding () || !dstStt.hasEncoding ()) {
359
+ const auto srcTp = getSparseTensorType (srcTensor);
360
+ const auto dstTp = getSparseTensorType (op.getResult ());
361
+ if (!srcTp.hasEncoding () || !dstTp.hasEncoding ())
367
362
return failure ();
368
- }
369
363
370
364
// Generate code to represent the static dimension constants or compute
371
365
// the dynamic dimension values.
372
366
SmallVector<Value> srcSizes;
373
367
sizesForTensor (rewriter, srcSizes, loc, srcTp, srcTensor);
374
368
SmallVector<Value> dstSizes;
375
369
SmallVector<Value> dstDynSizes;
376
- if (dstTp.hasStaticShape ()) {
377
- for (auto d : dstTp.getShape ())
370
+ if (dstTp.hasStaticDimShape ()) {
371
+ for (Dimension d : dstTp.getDimShape ())
378
372
dstSizes.push_back (constantIndex (rewriter, loc, d));
379
373
} else {
380
- ArrayRef<int64_t > dstShape = dstTp.getShape ();
374
+ ArrayRef<DynSize > dstShape = dstTp.getDimShape ();
381
375
genReshapeDstShape (loc, rewriter, dstSizes, srcSizes, dstShape,
382
376
op.getReassociationIndices ());
383
377
for (auto [idx, shape] : llvm::enumerate (dstShape)) {
@@ -389,8 +383,8 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
389
383
// Only need a unordered COO buffer if input and output are not sorted
390
384
// in the same way.
391
385
Type bufferTp =
392
- srcStt .isAllOrdered () && srcStt .isIdentity () && dstStt .isIdentity ()
393
- ? dstTp
386
+ srcTp .isAllOrdered () && srcTp .isIdentity () && dstTp .isIdentity ()
387
+ ? dstTp. getRankedTensorType ()
394
388
: getUnorderedCOOFromType (dstTp);
395
389
396
390
Value buffer =
@@ -406,11 +400,12 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
406
400
// followed by an optional
407
401
// %t = sparse_tensor.cast %tmp
408
402
// depending on whether the input/output are sorted in the same way.
403
+ const auto encSrc = srcTp.getEncoding ();
409
404
ForeachOp foreachOp = rewriter.create <ForeachOp>(
410
405
loc, srcTensor, buffer,
411
406
[&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
412
407
ValueRange reduc) {
413
- const Dimension dimRank = srcTp.getRank ();
408
+ const Dimension dimRank = srcTp.getDimRank ();
414
409
SmallVector<Value> srcDcvs;
415
410
srcDcvs.reserve (dimRank);
416
411
for (Dimension d = 0 ; d < dimRank; d++) {
@@ -427,7 +422,8 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
427
422
428
423
Value t = rewriter.create <LoadOp>(loc, foreachOp.getResult (0 ), true );
429
424
if (bufferTp != dstTp) {
430
- Value converted = rewriter.create <ConvertOp>(loc, dstTp, t).getResult ();
425
+ auto dstRTT = dstTp.getRankedTensorType ();
426
+ Value converted = rewriter.create <ConvertOp>(loc, dstRTT, t).getResult ();
431
427
rewriter.create <DeallocTensorOp>(loc, t);
432
428
t = converted;
433
429
}
0 commit comments