Skip to content

[mlir][Vector] Support 0-d vectors natively in TransferOpReduceRank #112907

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 0 additions & 21 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,31 +358,10 @@ struct TransferOpReduceRank
op, "map is not a minor identity with broadcasting");
}

// TODO: support zero-dimension vectors natively. See:
// https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097.
// In the meantime, lower these to a scalar load when they pop up.
if (reducedShapeRank == 0) {
Value newRead;
if (isa<TensorType>(op.getShapedType())) {
newRead = rewriter.create<tensor::ExtractOp>(
op.getLoc(), op.getSource(), op.getIndices());
} else {
newRead = rewriter.create<memref::LoadOp>(
op.getLoc(), originalVecType.getElementType(), op.getSource(),
op.getIndices());
}
return rewriter
.create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
.getVector();
}

SmallVector<int64_t> newShape(
originalVecType.getShape().take_back(reducedShapeRank));
SmallVector<bool> newScalableDims(
originalVecType.getScalableDims().take_back(reducedShapeRank));
// Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering.
if (newShape.empty())
return rewriter.notifyMatchFailure(op, "rank-reduced vector is 0-d");

VectorType newReadType = VectorType::get(
newShape, originalVecType.getElementType(), newScalableDims);
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -503,8 +503,8 @@ func.func @transfer_read_within_async_execute(%A : memref<2x2xf32>) -> !async.to

// CHECK-LABEL: transfer_read_with_tensor
func.func @transfer_read_with_tensor(%arg: tensor<f32>) -> vector<1xf32> {
// CHECK: %[[EXTRACTED:.*]] = tensor.extract %{{.*}}[] : tensor<f32>
// CHECK-NEXT: %[[RESULT:.*]] = vector.broadcast %[[EXTRACTED]] : f32 to vector<1xf32>
// CHECK: %[[EXTRACTED:.*]] = vector.transfer_read %{{.*}}[], %{{.*}} : tensor<f32>, vector<f32>
// CHECK-NEXT: %[[RESULT:.*]] = vector.broadcast %[[EXTRACTED]] : vector<f32> to vector<1xf32>
// CHECK-NEXT: return %[[RESULT]] : vector<1xf32>
%f0 = arith.constant 0.0 : f32
%0 = vector.transfer_read %arg[], %f0 {permutation_map = affine_map<()->(0)>} :
Expand Down
20 changes: 10 additions & 10 deletions mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,7 @@ func.func @vectorize_nd_tensor_extract_transfer_read_basic(
// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[IDX1]], %[[IDX2]], %[[C0:.*]]], %[[CST_0]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
// CHECK: vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>

// Same as example above, but reading into a column tensor. Note that after the
// vectorization, the `TransferOpReduceRank` will replace
// `vector.transfer_read` with `tensor.extract -> scalar`.
// Same as example above, but reading into a column tensor.

// TODO: Currently this fails to vectorise when the indices are non-constant.

Expand All @@ -135,9 +133,10 @@ func.func @vectorize_nd_tensor_extract_transfer_read_basic_column(
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_basic_column(
// CHECK-SAME: %[[INPUT:.*]]: tensor<3x3x3xf32>,
// CHECK-SAME: %[[OUTPUT:.*]]: tensor<3x1x1xf32>)
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[EXTRACT:.*]] = tensor.extract %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : tensor<3x3x3xf32>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTRACT]] : f32 to vector<3x1x1xf32>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[READ:.*]] = vector.transfer_read %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[CST_0]] : tensor<3x3x3xf32>, vector<f32>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[READ]] : vector<f32> to vector<3x1x1xf32>
// CHECK: %[[RES:.*]] = vector.transfer_write %[[BCAST]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<3x1x1xf32>, tensor<3x1x1xf32>
// CHECK: return %[[RES]] : tensor<3x1x1xf32>

Expand Down Expand Up @@ -514,8 +513,9 @@ func.func @vectorize_nd_tensor_extract_with_tensor_extract(%input_1: tensor<1x20
// CHECK-SAME: %[[INPUT_2:.*]]: tensor<257x24xf32>,
// CHECK: %[[EXTRACTED_0_IDX_0:.*]] = arith.constant 0 : index
// CHECK: %[[EXTRACTED_0_IDX_1:.*]] = vector.extractelement %{{.*}}[%{{.*}} : i32] : vector<4xindex>
// First `tensor.extract` from the generic Op - loop invariant scalar load.
// CHECK: tensor.extract %[[INPUT_1]][%[[EXTRACTED_0_IDX_0]], %[[EXTRACTED_0_IDX_1]]] : tensor<1x20xi32>
// First `vector.transfer_read` from the generic Op - loop invariant scalar load.
// CHECK: vector.transfer_read %[[INPUT_1]][%[[EXTRACTED_0_IDX_0]], %[[EXTRACTED_0_IDX_1]]]
// CHECK-SAME: tensor<1x20xi32>, vector<i32>
// The following `tensor.extract` from the generic Op is a contiguous load (all Ops used
// for address calculation also satisfy the required conditions).
// CHECK: vector.transfer_read %[[INPUT_2]][%{{.*}}, %{{.*}}, %{{.*}} {in_bounds = [true, true]} : tensor<257x24xf32>, vector<1x4xf32>
Expand Down Expand Up @@ -718,8 +718,8 @@ func.func @vectorize_0d_tensor_extract(%arg0: tensor<f32>, %arg2: tensor<1x1x3xf

// CHECK-LABEL: func.func @vectorize_0d_tensor_extract(
// CHECK-SAME: %[[ARG_0:.*]]: tensor<f32>
// CHECK: %[[EXTRACT:.*]] = tensor.extract %[[ARG_0]][] : tensor<f32>
// CHECK: vector.broadcast %[[EXTRACT]] : f32 to vector<1x1x3xf32>
// CHECK: %[[EXTRACT:.*]] = vector.transfer_read %[[ARG_0]][], %{{.+}} : tensor<f32>
// CHECK: vector.broadcast %[[EXTRACT]] : vector<f32> to vector<1x1x3xf32>

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ func.func @vector_transfer_ops_0d_memref(%mem: memref<f32>, %vec: vector<1x1x1xf
func.func @vector_transfer_ops_0d_tensor(%src: tensor<f32>) -> vector<1xf32> {
%f0 = arith.constant 0.0 : f32

// CHECK-NEXT: %[[S:.*]] = tensor.extract %[[SRC]][] : tensor<f32>
// CHECK-NEXT: %[[V:.*]] = vector.broadcast %[[S]] : f32 to vector<1xf32>
// CHECK: %[[S:.*]] = vector.transfer_read %[[SRC]][]
// CHECK: %[[V:.*]] = vector.broadcast %[[S]] : vector<f32> to vector<1xf32>
%res = vector.transfer_read %src[], %f0 {in_bounds = [true], permutation_map = affine_map<()->(0)>} :
tensor<f32>, vector<1xf32>

Expand Down
Loading