Skip to content

Commit a8f3d30

Browse files
authored
[mlir] Add dependent TensorDialect to ConvertVectorToLLVM pass (llvm#108045)
This patch registers the tensor dialect as dependent of the ConvertVectorToLLVM. This which fixes a crash when `vector.transfer_write` is used with dynamic tensor type. The MaterializeTransferMask pattern would call `vector::createOrFoldDimOp` which creates a `tensor.dim` operation. Fixes llvm#107805.
1 parent 596e7cc commit a8f3d30

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/Dialect/Func/IR/FuncOps.h"
2020
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
2121
#include "mlir/Dialect/MemRef/IR/MemRef.h"
22+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
2223
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
2324
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
2425
#include "mlir/Dialect/X86Vector/Transforms.h"
@@ -45,6 +46,7 @@ struct ConvertVectorToLLVMPass
4546
registry.insert<LLVM::LLVMDialect>();
4647
registry.insert<arith::ArithDialect>();
4748
registry.insert<memref::MemRefDialect>();
49+
registry.insert<tensor::TensorDialect>();
4850
if (armNeon)
4951
registry.insert<arm_neon::ArmNeonDialect>();
5052
if (armSVE)

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2521,6 +2521,16 @@ func.func @transfer_write_1d_scalable_mask(%arg0: memref<1x?xf32>, %vec: vector<
25212521

25222522
// -----
25232523

2524+
// CHECK-LABEL: func @transfer_write_tensor
2525+
// CHECK: vector.transfer_write
2526+
func.func @transfer_write_tensor(%arg0: vector<4xf32>,%arg1: tensor<?xf32>) -> tensor<?xf32> {
2527+
%c0 = arith.constant 0 : index
2528+
%0 = vector.transfer_write %arg0, %arg1[%c0] : vector<4xf32>, tensor<?xf32>
2529+
return %0 : tensor<?xf32>
2530+
}
2531+
2532+
// -----
2533+
25242534
func.func @genbool_0d_f() -> vector<i1> {
25252535
%0 = vector.constant_mask [0] : vector<i1>
25262536
return %0 : vector<i1>

0 commit comments

Comments
 (0)