Skip to content

Commit 85ff270

Browse files
[mlir][std] Add DimOp folding for dim(tensor_load(m)) -> dim(m).
Differential Revision: https://reviews.llvm.org/D90755
1 parent 1664462 commit 85ff270

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

mlir/lib/Dialect/StandardOps/IR/Ops.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1561,23 +1561,29 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
15611561
}
15621562
}
15631563

1564+
Operation *definingOp = memrefOrTensor().getDefiningOp();
1565+
// dim(tensor_load(memref)) -> dim(memref)
1566+
if (auto tensorLoadOp = dyn_cast_or_null<TensorLoadOp>(definingOp)) {
1567+
setOperand(0, tensorLoadOp.memref());
1568+
return getResult();
1569+
}
1570+
15641571
// Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
15651572
auto memrefType = argTy.dyn_cast<MemRefType>();
15661573
if (!memrefType)
15671574
return {};
15681575

15691576
// The size at the given index is now known to be a dynamic size of a memref.
1570-
auto *memref = memrefOrTensor().getDefiningOp();
15711577
unsigned unsignedIndex = index.getValue().getZExtValue();
1572-
if (auto alloc = dyn_cast_or_null<AllocOp>(memref))
1578+
if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
15731579
return *(alloc.getDynamicSizes().begin() +
15741580
memrefType.getDynamicDimIndex(unsignedIndex));
15751581

1576-
if (auto view = dyn_cast_or_null<ViewOp>(memref))
1582+
if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
15771583
return *(view.getDynamicSizes().begin() +
15781584
memrefType.getDynamicDimIndex(unsignedIndex));
15791585

1580-
if (auto subview = dyn_cast_or_null<SubViewOp>(memref)) {
1586+
if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
15811587
assert(subview.isDynamicSize(unsignedIndex) &&
15821588
"Expected dynamic subview size");
15831589
return subview.getDynamicSize(unsignedIndex);

mlir/test/Dialect/Standard/canonicalize.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,16 @@ func @no_fold_tensor_to_memref_of_tensor_load(%arg0: memref<?xf32, 2>) -> memref
3131
%1 = tensor_to_memref %0 : memref<?xf32, 7>
3232
return %1 : memref<?xf32, 7>
3333
}
34+
35+
// Test case: Basic folding of dim(tensor_load(m)) -> dim(m).
36+
// CHECK-LABEL: func @dim_of_tensor_load(
37+
// CHECK-SAME: %[[MEMREF:[0-9a-z]*]]: memref<?xf32>
38+
// CHECK: %[[C0:.*]] = constant 0
39+
// CHECK: %[[D:.*]] = dim %[[MEMREF]], %[[C0]]
40+
// CHECK: return %[[D]] : index
41+
func @dim_of_tensor_load(%arg0: memref<?xf32>) -> index {
42+
%c0 = constant 0 : index
43+
%0 = tensor_load %arg0 : memref<?xf32>
44+
%1 = dim %0, %c0 : tensor<?xf32>
45+
return %1 : index
46+
}

0 commit comments

Comments
 (0)