Commit 8e01e2e

[mlir][Vector] Fold tensor_cast + vector.transfer_read
Differential Revision: https://reviews.llvm.org/D96988
1 parent 0c087a6 commit 8e01e2e
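
For context, a before/after sketch of the rewrite this fold enables, adapted from the test case added in this commit. The function name @fold_example is illustrative only, and the "after" form is inferred from the CHECK line in the new test:

// Before canonicalization: the transfer_read goes through a shape-erasing
// tensor.cast.
func @fold_example(%A: tensor<4x8xf32>) -> (vector<4x8xf32>) {
  %c0 = constant 0 : index
  %f0 = constant 0.0 : f32
  %0 = tensor.cast %A : tensor<4x8xf32> to tensor<?x?xf32>
  %1 = vector.transfer_read %0[%c0, %c0], %f0 : tensor<?x?xf32>, vector<4x8xf32>
  return %1 : vector<4x8xf32>
}

// After canonicalization: the cast is bypassed; with the static 4x8 source
// shape the read is provably in-bounds, so it also folds to
// masked = [false, false], matching the test's CHECK line.
func @fold_example(%A: tensor<4x8xf32>) -> (vector<4x8xf32>) {
  %c0 = constant 0 : index
  %f0 = constant 0.0 : f32
  %1 = vector.transfer_read %A[%c0, %c0], %f0 {masked = [false, false]}
      : tensor<4x8xf32>, vector<4x8xf32>
  return %1 : vector<4x8xf32>
}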

File tree

2 files changed: +29, -0 lines


mlir/lib/Dialect/Vector/VectorOps.cpp

Lines changed: 15 additions & 0 deletions
@@ -13,6 +13,7 @@

 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/VectorUtils.h"
 #include "mlir/IR/AffineExpr.h"

@@ -2408,6 +2409,18 @@ static LogicalResult foldMemRefCast(Operation *op) {
   return success(folded);
 }

+static LogicalResult foldTensorCast(Operation *op) {
+  bool folded = false;
+  for (OpOperand &operand : op->getOpOperands()) {
+    auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
+    if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
+      operand.set(castOp.getOperand());
+      folded = true;
+    }
+  }
+  return success(folded);
+}
+
 template <typename TransferOp>
 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
   // TODO: support more aggressive createOrFold on:

@@ -2460,6 +2473,8 @@ OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
     return getResult();
   if (succeeded(foldMemRefCast(*this)))
     return getResult();
+  if (succeeded(foldTensorCast(*this)))
+    return getResult();
   return OpFoldResult();
 }
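
A note on the guard: tensor::canFoldIntoConsumerOp only permits the fold when the cast source carries at least as much static shape information as its result, so only shape-erasing casts are bypassed. A sketch of the distinction (my own illustration, not part of this commit; @fold_guard, %A, and %B are hypothetical names):

func @fold_guard(%A: tensor<4x8xf32>, %B: tensor<?x?xf32>) -> (vector<4x8xf32>, vector<4x8xf32>) {
  %c0 = constant 0 : index
  %f0 = constant 0.0 : f32

  // Foldable: the cast only erases static shape information, so the read can
  // use %A directly.
  %0 = tensor.cast %A : tensor<4x8xf32> to tensor<?x?xf32>
  %1 = vector.transfer_read %0[%c0, %c0], %f0 : tensor<?x?xf32>, vector<4x8xf32>

  // Not foldable: the cast adds static shape information; bypassing it would
  // leave the read with a less statically shaped source.
  %2 = tensor.cast %B : tensor<?x?xf32> to tensor<4x8xf32>
  %3 = vector.transfer_read %2[%c0, %c0], %f0 : tensor<4x8xf32>, vector<4x8xf32>

  return %1, %3 : vector<4x8xf32>, vector<4x8xf32>
}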

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 14 additions & 0 deletions
@@ -267,6 +267,20 @@ func @cast_transfers(%A: memref<4x8xf32>) -> (vector<4x8xf32>) {

 // -----

+// CHECK-LABEL: cast_transfers
+func @cast_transfers(%A: tensor<4x8xf32>) -> (vector<4x8xf32>) {
+  %c0 = constant 0 : index
+  %f0 = constant 0.0 : f32
+  %0 = tensor.cast %A : tensor<4x8xf32> to tensor<?x?xf32>
+
+  // CHECK: vector.transfer_read %{{.*}} {masked = [false, false]} : tensor<4x8xf32>, vector<4x8xf32>
+  %1 = vector.transfer_read %0[%c0, %c0], %f0 : tensor<?x?xf32>, vector<4x8xf32>
+
+  return %1 : vector<4x8xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @insert_extract_transpose_2d(
 // CHECK-SAME: %[[V:[a-zA-Z0-9]*]]: vector<2x3xf32>,
 // CHECK-SAME: %[[F0:[a-zA-Z0-9]*]]: f32,
