Skip to content

Commit de03121

Browse files
author
MaheshRavishankar
committed
[mlir] Add canonicalization from tensor_cast to dim op.
Fold a `tensor_cast` -> `dim` to take the `dim` of the original tensor. Differential Revision: https://reviews.llvm.org/D93492
1 parent 3d56644 commit de03121

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

mlir/lib/Dialect/StandardOps/IR/Ops.cpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1472,11 +1472,29 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
14721472
return success();
14731473
}
14741474
};
1475+
1476+
/// Fold dim of a dim of a cast into the the dim of the source of the tensor
1477+
/// cast.
1478+
template <typename CastOpTy>
1479+
struct DimOfCastOp : public OpRewritePattern<DimOp> {
1480+
using OpRewritePattern<DimOp>::OpRewritePattern;
1481+
1482+
LogicalResult matchAndRewrite(DimOp dimOp,
1483+
PatternRewriter &rewriter) const override {
1484+
auto castOp = dimOp.memrefOrTensor().getDefiningOp<CastOpTy>();
1485+
if (!castOp)
1486+
return failure();
1487+
Value newSource = castOp.getOperand();
1488+
rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index());
1489+
return success();
1490+
}
1491+
};
1492+
14751493
} // end anonymous namespace.
14761494

14771495
void DimOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
14781496
MLIRContext *context) {
1479-
results.insert<DimOfMemRefReshape>(context);
1497+
results.insert<DimOfMemRefReshape, DimOfCastOp<TensorCastOp>>(context);
14801498
}
14811499

14821500
// ---------------------------------------------------------------------------

mlir/test/Dialect/Standard/canonicalize.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,19 @@ func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
115115
%1 = dim %0, %c3 : memref<*xf32>
116116
return %1 : index
117117
}
118+
119+
// Test case: Folding dim(tensor_cast %0, %idx) -> dim %0, %idx
120+
// CHECK-LABEL: func @fold_dim_of_tensor_cast
121+
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x?xf32>
122+
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
123+
// CHECK-DAG: %[[C4:.+]] = constant 4 : index
124+
// CHECK: %[[T0:.+]] = dim %[[ARG0]], %[[C1]]
125+
// CHECK-NEXT: return %[[C4]], %[[T0]]
126+
func @fold_dim_of_tensor_cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
127+
%c0 = constant 0 : index
128+
%c1 = constant 1 : index
129+
%0 = tensor_cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
130+
%1 = dim %0, %c0 : tensor<?x?xf32>
131+
%2 = dim %0, %c1 : tensor<?x?xf32>
132+
return %1, %2: index, index
133+
}

0 commit comments

Comments
 (0)