
Commit c7a3346

[mlir][linalg] Fix scalable vectorisation of tensor.extract (llvm#100325)
This PR fixes one very specific aspect of vectorising `tensor.extract` Ops when targeting scalable vectors. Namely, it makes sure that the scalable flag is correctly propagated when creating `vector::ShapeCastOp`.

BEFORE:
```mlir
vector.shape_cast %idx_vec : vector<1x1x[4]xindex> to vector<4xindex>
```

AFTER:
```mlir
vector.shape_cast %idx_vec : vector<1x1x[4]xindex> to vector<[4]xindex>
```

This particular `ShapeCastOp` is created when generating an index for `vector.transfer_read` operations. Strictly speaking, the cast is not required; however, it makes the subsequent address calculation much simpler (*).

The following test is updated to demonstrate the use of `vector.shape_cast` by the vectoriser:

* `@masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguous`

A similar test with scalable vectors is also added.

(*) At this point in the vectoriser it is known that all leading dims in the index vector are "1".
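To make the footnote concrete, here is a minimal illustrative sketch, not part of the patch: the function and SSA names (`@sketch`, `%idx_vec`, `%pad`, ...) are invented, with shapes borrowed from the example above and the updated tests below. The index vector is flattened to 1-D and its 0th element becomes the scalar start index of a contiguous, masked `vector.transfer_read`:

```mlir
// Illustrative only - shows how the corrected shape_cast feeds the masked read.
func.func @sketch(%src: tensor<80x16xf32>, %idx_vec: vector<1x1x[4]xindex>,
                  %mask: vector<1x[4]xi1>, %pad: f32) -> vector<1x[4]xf32> {
  %c0 = arith.constant 0 : i32
  %c79 = arith.constant 79 : index
  // The cast fixed by this patch - the trailing [4] must stay scalable.
  %flat = vector.shape_cast %idx_vec : vector<1x1x[4]xindex> to vector<[4]xindex>
  // All leading dims are "1", so element 0 is the start of a contiguous read.
  %start = vector.extractelement %flat[%c0 : i32] : vector<[4]xindex>
  %read = vector.mask %mask {
    vector.transfer_read %src[%c79, %start], %pad {in_bounds = [true, true]}
      : tensor<80x16xf32>, vector<1x[4]xf32>
  } : vector<1x[4]xi1> -> vector<1x[4]xf32>
  return %read : vector<1x[4]xf32>
}
```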
Parent: d82df1b

2 files changed: +103, -17 lines


mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 5 additions & 4 deletions
```diff
@@ -1077,19 +1077,20 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
   // * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
   //   (0th) element and use that.
   SmallVector<Value> transferReadIdxs;
-  auto resTrailingDim = resultType.getShape().back();
   auto zero = rewriter.create<arith::ConstantOp>(
       loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
   for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
-    auto idx = bvm.lookup(extractOp.getIndices()[i]);
+    Value idx = bvm.lookup(extractOp.getIndices()[i]);
     if (idx.getType().isIndex()) {
       transferReadIdxs.push_back(idx);
       continue;
     }
 
     auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
-        loc, VectorType::get({resTrailingDim}, rewriter.getIndexType()),
-        bvm.lookup(extractOp.getIndices()[i]));
+        loc,
+        VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
+                        resultType.getScalableDims().back()),
+        idx);
     transferReadIdxs.push_back(
         rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
   }
```
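For the fixed-width path exercised by the updated static test, `resultType.getScalableDims().back()` is simply `false`, and the emitted cast can even degenerate to an identity (`vector<4xindex> to vector<4xindex>` in the CHECK lines below). A minimal sketch of that case, with an invented function name:

```mlir
// Illustrative only: an already rank-1, fixed-width index vector still goes
// through the same shape_cast + extractelement sequence; the cast is a no-op
// here but keeps the address calculation uniform.
func.func @sketch_fixed_width(%idx_vec: vector<4xindex>) -> index {
  %c0 = arith.constant 0 : i32
  %flat = vector.shape_cast %idx_vec : vector<4xindex> to vector<4xindex>
  %start = vector.extractelement %flat[%c0 : i32] : vector<4xindex>
  return %start : index
}
```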

mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir

Lines changed: 98 additions & 13 deletions
```diff
@@ -1,29 +1,52 @@
 // RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s
 
-func.func @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguous(%6: tensor<80x16xf32>, %arg0: index, %extracted_slice : tensor<1x3xf32>) -> tensor<1x3xf32> {
+func.func @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguous(
+    %src: tensor<80x16xf32>,
+    %output : tensor<1x3xf32>,
+    %idx: index) -> tensor<1x3xf32> {
+
   %c79 = arith.constant 79 : index
   %1 = linalg.generic {
     indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
     iterator_types = ["parallel", "parallel"]
-  } outs(%extracted_slice : tensor<1x3xf32>) {
+  } outs(%output : tensor<1x3xf32>) {
   ^bb0(%out: f32):
     %2 = linalg.index 1 : index
-    %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %arg0)
-    %extracted = tensor.extract %6[%c79, %3] : tensor<80x16xf32>
+    %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %idx)
+    %extracted = tensor.extract %src[%c79, %3] : tensor<80x16xf32>
     linalg.yield %extracted : f32
   } -> tensor<1x3xf32>
   return %1 : tensor<1x3xf32>
 }
 
 // CHECK-LABEL: func.func @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguous
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 3 : index
-// CHECK: %[[VAL_8:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_5]] : vector<1x4xi1>
-// CHECK: %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<1x3xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
-// CHECK: %[[VAL_11:.*]] = vector.broadcast {{.*}} : index to vector<4xindex>
-// CHECK: %[[VAL_12:.*]] = arith.addi {{.*}} : vector<4xindex>
-// CHECK: %[[VAL_20:.*]] = vector.mask %[[VAL_8]] { vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
-// CHECK: %[[VAL_22:.*]] = vector.mask %[[VAL_8]] { vector.transfer_write {{.*}} {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x3xf32> } : vector<1x4xi1> -> tensor<1x3xf32>
+// CHECK-SAME: %[[SRC:.*]]: tensor<80x16xf32>,
+// CHECK-SAME: %[[OUTPUT:.*]]: tensor<1x3xf32>,
+// CHECK-SAME: %[[IDX_IN:.*]]: index) -> tensor<1x3xf32> {
+
+/// Create the mask
+// CHECK-DAG: %[[DIM_0:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[DIM_1:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C79:.*]] = arith.constant 79 : index
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<1x4xi1>
+
+/// TODO: This transfer_read is redundant - remove
+// CHECK: vector.mask %[[MASK]] { vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<1x3xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
+
+/// Calculate the index vector
+// CHECK: %[[STEP:.*]] = vector.step : vector<4xindex>
+// CHECK: %[[IDX_BC:.*]] = vector.broadcast %[[IDX_IN]] : index to vector<4xindex>
+// CHECK: %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<4xindex>
+// CHECK: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<4xindex> to vector<4xindex>
+
+/// Extract the starting point from the index vector
+// CHECK: %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<4xindex>
+
+/// Final read and write
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
+// CHECK: %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[OUTPUT]]{{\[}}%[[C0_1]], %[[C0_1]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x3xf32> } : vector<1x4xi1> -> tensor<1x3xf32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -33,7 +56,69 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-// -----
+// -----
+
+// Identical to the above, but with scalable vectors.
+
+func.func @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguous_scalable(
+    %src: tensor<80x16xf32>,
+    %output : tensor<1x3xf32>,
+    %idx: index) -> tensor<1x3xf32> {
+
+  %c79 = arith.constant 79 : index
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } outs(%output : tensor<1x3xf32>) {
+  ^bb0(%out: f32):
+    %2 = linalg.index 1 : index
+    %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %idx)
+    %extracted = tensor.extract %src[%c79, %3] : tensor<80x16xf32>
+    linalg.yield %extracted : f32
+  } -> tensor<1x3xf32>
+
+  return %1 : tensor<1x3xf32>
+}
+
+// CHECK-LABEL: func.func @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguous_scalable
+// CHECK-SAME: %[[SRC:.*]]: tensor<80x16xf32>,
+// CHECK-SAME: %[[OUTPUT:.*]]: tensor<1x3xf32>,
+// CHECK-SAME: %[[IDX_IN:.*]]: index) -> tensor<1x3xf32> {
+
+/// Create the mask
+// CHECK-DAG: %[[DIM_0:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[DIM_1:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C79:.*]] = arith.constant 79 : index
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<1x[4]xi1>
+
+/// TODO: This transfer_read is redundant - remove
+// CHECK: vector.mask %[[MASK]] { vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<1x3xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>
+
+/// Calculate the index vector
+// CHECK: %[[STEP:.*]] = vector.step : vector<[4]xindex>
+// CHECK: %[[IDX_BC:.*]] = vector.broadcast %[[IDX_IN]] : index to vector<[4]xindex>
+// CHECK: %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<[4]xindex>
+// CHECK: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<[4]xindex> to vector<[4]xindex>
+
+/// Extract the starting point from the index vector
+// CHECK: %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<[4]xindex>
+
+/// Final read and write
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>
+// CHECK: %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK: vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[OUTPUT]]{{\[}}%[[C0_1]], %[[C0_1]]] {in_bounds = [true, true]} : vector<1x[4]xf32>, tensor<1x3xf32> } : vector<1x[4]xi1> -> tensor<1x3xf32>
+
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [1, [4]] {vectorize_nd_extract} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
 
 func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(%6: tensor<?x?xf32>, %arg0: index, %extracted_slice : tensor<?x?xf32>) -> tensor<?x?xf32> {
   %c79 = arith.constant 79 : index
```
