Skip to content

Commit c84061f

Browse files
authored
[mlir][vector] Fix a target-rank=0 unrolling (#73365)
Fixes #64269.

With this patch, calling `mlir-opt "-convert-vector-to-scf=full-unroll target-rank=0"` on

```mlir
func.func @main(%vec : vector<2xi32>) {
  %alloc = memref.alloc() : memref<4xi32>
  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %alloc[%c0] : vector<2xi32>, memref<4xi32>
  return
}
```

will result in

```mlir
module {
  func.func @main(%arg0: vector<2xi32>) {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %alloc = memref.alloc() : memref<4xi32>
    %0 = vector.extract %arg0[0] : i32 from vector<2xi32>
    %1 = vector.broadcast %0 : i32 to vector<i32>
    vector.transfer_write %1, %alloc[%c0] : vector<i32>, memref<4xi32>
    %2 = vector.extract %arg0[1] : i32 from vector<2xi32>
    %3 = vector.broadcast %2 : i32 to vector<i32>
    vector.transfer_write %3, %alloc[%c1] : vector<i32>, memref<4xi32>
    return
  }
}
```

I've also tried to proactively find other `target-rank=0` bugs, but couldn't find any. `options.targetRank` is only used 8 times throughout the `mlir` folder, all inside `VectorToSCF.cpp`. None of the other uses look like they could cause a crash.

I've also tried

```mlir
func.func @main(%vec : vector<2xi32>) -> vector<2xi32> {
  %alloc = memref.alloc() : memref<4xindex>
  %c0 = arith.constant 0 : index
  %out = vector.transfer_read %alloc[%c0], %c0 : memref<4xindex>, vector<2xi32>
  return %out : vector<2xi32>
}
```

with `"--convert-vector-to-scf=full-unroll target-rank=0"` and that also didn't crash. (Maybe obvious. I have to admit that I'm not very familiar with these ops.)
1 parent 2425e29 commit c84061f

File tree

2 files changed

+37
-6
lines changed

2 files changed

+37
-6
lines changed

mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2222
#include "mlir/Dialect/SCF/IR/SCF.h"
2323
#include "mlir/Dialect/Tensor/IR/Tensor.h"
24+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
2425
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
2526
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
2627
#include "mlir/IR/Builders.h"
@@ -1207,23 +1208,25 @@ struct UnrollTransferWriteConversion
12071208
/// accesses, and broadcasts and transposes in permutation maps.
12081209
LogicalResult matchAndRewrite(TransferWriteOp xferOp,
12091210
PatternRewriter &rewriter) const override {
1210-
if (xferOp.getVectorType().getRank() <= options.targetRank)
1211+
VectorType inputVectorTy = xferOp.getVectorType();
1212+
1213+
if (inputVectorTy.getRank() <= options.targetRank)
12111214
return failure();
1215+
12121216
if (isTensorOp(xferOp) && !options.lowerTensors)
12131217
return failure();
12141218
// Transfer ops that modify the element type are not supported atm.
1215-
if (xferOp.getVectorType().getElementType() !=
1219+
if (inputVectorTy.getElementType() !=
12161220
xferOp.getShapedType().getElementType())
12171221
return failure();
12181222

12191223
auto vec = getDataVector(xferOp);
1220-
auto xferVecType = xferOp.getVectorType();
1221-
if (xferVecType.getScalableDims()[0]) {
1224+
if (inputVectorTy.getScalableDims()[0]) {
12221225
// Cannot unroll a scalable dimension at compile time.
12231226
return failure();
12241227
}
12251228

1226-
int64_t dimSize = xferVecType.getShape()[0];
1229+
int64_t dimSize = inputVectorTy.getShape()[0];
12271230
Value source = xferOp.getSource(); // memref or tensor to be written to.
12281231
auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
12291232

@@ -1249,8 +1252,18 @@ struct UnrollTransferWriteConversion
12491252
auto extracted =
12501253
b.create<vector::ExtractOp>(loc, vec, extractionIndices);
12511254
auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
1255+
Value xferVec;
1256+
if (inputVectorTy.getRank() == 1) {
1257+
// When target-rank=0, unrolling would cause the vector input
1258+
// argument into `transfer_write` to become a scalar. We solve
1259+
// this by broadcasting the scalar to a 0D vector.
1260+
xferVec = b.create<vector::BroadcastOp>(
1261+
loc, VectorType::get({}, extracted.getType()), extracted);
1262+
} else {
1263+
xferVec = extracted;
1264+
}
12521265
auto newXferOp = b.create<vector::TransferWriteOp>(
1253-
loc, sourceType, extracted, source, xferIndices,
1266+
loc, sourceType, xferVec, source, xferIndices,
12541267
AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
12551268
inBoundsAttr);
12561269

mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf))" -split-input-file -allow-unregistered-dialect | FileCheck %s
22
// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf{full-unroll=true}))" -split-input-file -allow-unregistered-dialect | FileCheck %s --check-prefix=FULL-UNROLL
3+
// RUN: mlir-opt %s "-convert-vector-to-scf=full-unroll target-rank=0" -split-input-file -allow-unregistered-dialect | FileCheck %s --check-prefix=TARGET-RANK-ZERO
34

45
// CHECK-LABEL: func @vector_transfer_ops_0d(
56
func.func @vector_transfer_ops_0d(%M: memref<f32>) {
@@ -748,3 +749,20 @@ func.func @cannot_fully_unroll_transfer_write_of_nd_scalable_vector(%vec: vector
748749
vector.transfer_write %vec, %memref[%c0, %c0] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
749750
return
750751
}
752+
753+
// -----
754+
755+
// TARGET-RANK-ZERO-LABEL: func @unroll_transfer_write_target_rank_zero
756+
func.func @unroll_transfer_write_target_rank_zero(%vec : vector<2xi32>) {
757+
%alloc = memref.alloc() : memref<4xi32>
758+
%c0 = arith.constant 0 : index
759+
vector.transfer_write %vec, %alloc[%c0] : vector<2xi32>, memref<4xi32>
760+
return
761+
}
762+
// TARGET-RANK-ZERO: %[[ALLOC:.*]] = memref.alloc() : memref<4xi32>
763+
// TARGET-RANK-ZERO: %[[EXTRACTED1:.*]] = vector.extract {{.*}} : i32 from vector<2xi32>
764+
// TARGET-RANK-ZERO: %[[BROADCASTED1:.*]] = vector.broadcast %[[EXTRACTED1]] : i32 to vector<i32>
765+
// TARGET-RANK-ZERO: vector.transfer_write %[[BROADCASTED1]], %[[ALLOC]]{{.*}} : vector<i32>, memref<4xi32>
766+
// TARGET-RANK-ZERO: %[[EXTRACTED2:.*]] = vector.extract {{.*}} : i32 from vector<2xi32>
767+
// TARGET-RANK-ZERO: %[[BROADCASTED2:.*]] = vector.broadcast %[[EXTRACTED2]] : i32 to vector<i32>
768+
// TARGET-RANK-ZERO: vector.transfer_write %[[BROADCASTED2]], %[[ALLOC]]{{.*}} : vector<i32>, memref<4xi32>

0 commit comments

Comments
 (0)