@@ -1561,23 +1561,29 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
1561
1561
}
1562
1562
}
1563
1563
1564
+ Operation *definingOp = memrefOrTensor ().getDefiningOp ();
1565
+ // dim(tensor_load(memref)) -> dim(memref)
1566
+ if (auto tensorLoadOp = dyn_cast_or_null<TensorLoadOp>(definingOp)) {
1567
+ setOperand (0 , tensorLoadOp.memref ());
1568
+ return getResult ();
1569
+ }
1570
+
1564
1571
// Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
1565
1572
auto memrefType = argTy.dyn_cast <MemRefType>();
1566
1573
if (!memrefType)
1567
1574
return {};
1568
1575
1569
1576
// The size at the given index is now known to be a dynamic size of a memref.
1570
- auto *memref = memrefOrTensor ().getDefiningOp ();
1571
1577
unsigned unsignedIndex = index.getValue ().getZExtValue ();
1572
- if (auto alloc = dyn_cast_or_null<AllocOp>(memref ))
1578
+ if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp ))
1573
1579
return *(alloc.getDynamicSizes ().begin () +
1574
1580
memrefType.getDynamicDimIndex (unsignedIndex));
1575
1581
1576
- if (auto view = dyn_cast_or_null<ViewOp>(memref ))
1582
+ if (auto view = dyn_cast_or_null<ViewOp>(definingOp ))
1577
1583
return *(view.getDynamicSizes ().begin () +
1578
1584
memrefType.getDynamicDimIndex (unsignedIndex));
1579
1585
1580
- if (auto subview = dyn_cast_or_null<SubViewOp>(memref )) {
1586
+ if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp )) {
1581
1587
assert (subview.isDynamicSize (unsignedIndex) &&
1582
1588
" Expected dynamic subview size" );
1583
1589
return subview.getDynamicSize (unsignedIndex);
0 commit comments