Skip to content

[mlir][vector] Fix a target-rank=0 unrolling #73365

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/Builders.h"
Expand Down Expand Up @@ -1207,23 +1208,25 @@ struct UnrollTransferWriteConversion
/// accesses, and broadcasts and transposes in permutation maps.
LogicalResult matchAndRewrite(TransferWriteOp xferOp,
PatternRewriter &rewriter) const override {
if (xferOp.getVectorType().getRank() <= options.targetRank)
VectorType inputVectorTy = xferOp.getVectorType();

if (inputVectorTy.getRank() <= options.targetRank)
return failure();

if (isTensorOp(xferOp) && !options.lowerTensors)
return failure();
// Transfer ops that modify the element type are not supported atm.
if (xferOp.getVectorType().getElementType() !=
if (inputVectorTy.getElementType() !=
xferOp.getShapedType().getElementType())
return failure();

auto vec = getDataVector(xferOp);
auto xferVecType = xferOp.getVectorType();
if (xferVecType.getScalableDims()[0]) {
if (inputVectorTy.getScalableDims()[0]) {
// Cannot unroll a scalable dimension at compile time.
return failure();
}

int64_t dimSize = xferVecType.getShape()[0];
int64_t dimSize = inputVectorTy.getShape()[0];
Value source = xferOp.getSource(); // memref or tensor to be written to.
auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();

Expand All @@ -1249,8 +1252,18 @@ struct UnrollTransferWriteConversion
auto extracted =
b.create<vector::ExtractOp>(loc, vec, extractionIndices);
auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
Value xferVec;
if (inputVectorTy.getRank() == 1) {
// When target-rank=0, unrolling would cause the vector input
// argument to `transfer_write` to become a scalar. We solve
// this by broadcasting the scalar to a 0D vector.
xferVec = b.create<vector::BroadcastOp>(
loc, VectorType::get({}, extracted.getType()), extracted);
} else {
xferVec = extracted;
}
auto newXferOp = b.create<vector::TransferWriteOp>(
loc, sourceType, extracted, source, xferIndices,
loc, sourceType, xferVec, source, xferIndices,
AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
inBoundsAttr);

Expand Down
18 changes: 18 additions & 0 deletions mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf))" -split-input-file -allow-unregistered-dialect | FileCheck %s
// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf{full-unroll=true}))" -split-input-file -allow-unregistered-dialect | FileCheck %s --check-prefix=FULL-UNROLL
// RUN: mlir-opt %s "-convert-vector-to-scf=full-unroll target-rank=0" -split-input-file -allow-unregistered-dialect | FileCheck %s --check-prefix=TARGET-RANK-ZERO

// CHECK-LABEL: func @vector_transfer_ops_0d(
func.func @vector_transfer_ops_0d(%M: memref<f32>) {
Expand Down Expand Up @@ -748,3 +749,20 @@ func.func @cannot_fully_unroll_transfer_write_of_nd_scalable_vector(%vec: vector
vector.transfer_write %vec, %memref[%c0, %c0] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
return
}

// -----

// TARGET-RANK-ZERO-LABEL: func @unroll_transfer_write_target_rank_zero
// Writes a rank-1 vector<2xi32> into a memref<4xi32> at index 0. This case is
// checked only under the TARGET-RANK-ZERO prefix (the RUN line that passes
// `full-unroll target-rank=0`); the checks that follow expect one
// extract / broadcast-to-vector<i32> / transfer_write sequence per element.
func.func @unroll_transfer_write_target_rank_zero(%vec : vector<2xi32>) {
%alloc = memref.alloc() : memref<4xi32>
%c0 = arith.constant 0 : index
vector.transfer_write %vec, %alloc[%c0] : vector<2xi32>, memref<4xi32>
return
}
// TARGET-RANK-ZERO: %[[ALLOC:.*]] = memref.alloc() : memref<4xi32>
// TARGET-RANK-ZERO: %[[EXTRACTED1:.*]] = vector.extract {{.*}} : i32 from vector<2xi32>
// TARGET-RANK-ZERO: %[[BROADCASTED1:.*]] = vector.broadcast %[[EXTRACTED1]] : i32 to vector<i32>
// TARGET-RANK-ZERO: vector.transfer_write %[[BROADCASTED1]], %[[ALLOC]]{{.*}} : vector<i32>, memref<4xi32>
// TARGET-RANK-ZERO: %[[EXTRACTED2:.*]] = vector.extract {{.*}} : i32 from vector<2xi32>
// TARGET-RANK-ZERO: %[[BROADCASTED2:.*]] = vector.broadcast %[[EXTRACTED2]] : i32 to vector<i32>
// TARGET-RANK-ZERO: vector.transfer_write %[[BROADCASTED2]], %[[ALLOC]]{{.*}} : vector<i32>, memref<4xi32>