Commit a1d86c8

[Linalg][Vectorization] Add support for linalg vectorization case with outer non-unit dim
1 parent bdf0224 commit a1d86c8

2 files changed: +74 -0 lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 26 additions & 0 deletions
```diff
@@ -1079,6 +1079,32 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
       continue;
     }
 
+    auto idxType = dyn_cast<VectorType>(idx.getType());
+
+    if (idxType && idxType.getShape().size() == resultType.getShape().size()) {
+      auto maxElement = std::max_element(resultType.getShape().begin(),
+                                         resultType.getShape().end());
+      auto maxElementDim =
+          std::distance(resultType.getShape().begin(), maxElement);
+      // The result type is not all unit dims except the innermost dim, so
+      // insert a transpose op to make it all unit dims except the innermost
+      // dim.
+      if (maxElementDim != resultType.getShape().size() - 1) {
+        SmallVector<int64_t> transposition = llvm::to_vector<16>(
+            llvm::seq<int64_t>(0, resultType.getShape().size()));
+        std::swap(transposition.back(), transposition[maxElementDim]);
+        auto transposeOp =
+            rewriter.create<vector::TransposeOp>(loc, idx, transposition);
+        auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
+            loc,
+            VectorType::get(*maxElement, rewriter.getIndexType(),
+                            resultType.getScalableDims().back()),
+            transposeOp);
+        transferReadIdxs.push_back(rewriter.create<vector::ExtractElementOp>(
+            loc, indexAs1dVector, zero));
+        continue;
+      }
+    }
     auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
         loc,
         VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
```
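The core of the change is the permutation the new code builds: start from the identity permutation over the result rank, then swap the innermost position with the position of the largest (non-unit) dim so that dim ends up innermost. Below is a minimal standalone C++ sketch of just that computation; it is not the commit's code, and a plain `std::vector<int64_t>` stands in for the `ArrayRef<int64_t>` returned by `resultType.getShape()`:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // Stand-in for resultType.getShape(); the non-unit dim 8 is outermost,
  // as in the 8x1 result vector of the test below.
  std::vector<int64_t> shape = {8, 1};

  // Locate the largest dim and its position, mirroring the patch's
  // std::max_element / std::distance pair.
  auto maxElement = std::max_element(shape.begin(), shape.end());
  auto maxElementDim = std::distance(shape.begin(), maxElement);

  // Identity permutation [0, 1, ..., rank-1]; if the largest dim is not
  // already innermost, swap it with the last position.
  std::vector<int64_t> transposition(shape.size());
  std::iota(transposition.begin(), transposition.end(), 0);
  if (maxElementDim != static_cast<int64_t>(shape.size()) - 1)
    std::swap(transposition.back(), transposition[maxElementDim]);

  // For the 8x1 shape this prints "1 0": transpose 8x1 -> 1x8, after which
  // a shape_cast to a rank-1 vector of 8 elements is legal.
  for (int64_t d : transposition)
    std::cout << d << ' ';
  std::cout << '\n';
}
```

For the 8x1 result vector in the test below this yields the permutation [1, 0], which matches the second `vector.transpose` the CHECK lines expect.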

mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir

Lines changed: 48 additions & 0 deletions
```diff
@@ -253,6 +253,54 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>
+func.func @vectorize_nd_tensor_extract_transfer_without_outer_unit_dim(%arg0: tensor<8x128x768xf32>, %arg1 : index) -> tensor<8x1xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = tensor.empty() : tensor<8x1xf32>
+  %1 = linalg.generic {
+    indexing_maps = [#map],
+    iterator_types = ["parallel", "parallel"]
+  } outs(%0 : tensor<8x1xf32>) {
+  ^bb0(%arg5: f32):
+    %2 = linalg.index 0 : index
+    %3 = linalg.index 1 : index
+    %4 = affine.apply #map1(%arg1, %3, %arg1)
+    %extracted = tensor.extract %arg0[%2, %c0, %4] : tensor<8x128x768xf32>
+    linalg.yield %extracted : f32
+  } -> tensor<8x1xf32>
+  return %1 : tensor<8x1xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg2: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg2 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {vectorize_nd_extract} : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_without_outer_unit_dim
+// CHECK-SAME: %[[ARG0:.*]]: tensor<8x128x768xf32>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+// CHECK-DAG: %[[C0_i32:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[IDX0:.*]] = tensor.empty() : tensor<8x1xf32>
+// CHECK: %[[IDX1:.*]] = vector.broadcast %[[CST_0]] : vector<8xindex> to vector<1x8xindex>
+// CHECK: %[[IDX2:.*]] = vector.transpose %[[IDX1]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
+// CHECK: %[[IDX3:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index
+// CHECK: %[[IDX4:.*]] = vector.transpose %[[IDX2]], [1, 0] : vector<8x1xindex> to vector<1x8xindex>
+// CHECK: %[[IDX5:.*]] = vector.shape_cast %[[IDX4]] : vector<1x8xindex> to vector<8xindex>
+// CHECK: %[[IDX6:.*]] = vector.extractelement %[[IDX5]][%[[C0_i32]] : i32] : vector<8xindex>
+// CHECK: %[[IDX7:.*]] = vector.transfer_read %[[ARG0]][%[[IDX6]], %[[C0]], %[[IDX3]]], %[[CST]] {in_bounds = [true, true]} : tensor<8x128x768xf32>, vector<8x1xf32>
+// CHECK: vector.transfer_write %[[IDX7]], %[[IDX0]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
+
 // -----
 
 #map = affine_map<(d0) -> (d0)>
```
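As a reading aid for the CHECK sequence, here is a small C++ simulation of the index chain for `linalg.index 0` (a sketch of the data flow only, not code from the commit): the `dense<[0, 1, ..., 7]>` vector is broadcast to 1x8 and transposed to 8x1 by the generic vectorization, and the new code path then transposes it back to 1x8, shape-casts it to 8 elements, and extracts element 0 as the scalar base index of the contiguous `vector.transfer_read`:

```cpp
#include <array>
#include <cstdint>
#include <iostream>

int main() {
  // %[[CST_0]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]>.
  std::array<int64_t, 8> cst0{0, 1, 2, 3, 4, 5, 6, 7};

  // vector.broadcast : vector<8xindex> to vector<1x8xindex> (%[[IDX1]]).
  std::array<std::array<int64_t, 8>, 1> idx1{cst0};

  // vector.transpose [1, 0] : vector<1x8xindex> to vector<8x1xindex>
  // (%[[IDX2]]); the non-unit dim is now outermost, triggering the new path.
  std::array<std::array<int64_t, 1>, 8> idx2{};
  for (int i = 0; i < 8; ++i)
    idx2[i][0] = idx1[0][i];

  // New path: vector.transpose back to vector<1x8xindex> (%[[IDX4]]) ...
  std::array<std::array<int64_t, 8>, 1> idx4{};
  for (int i = 0; i < 8; ++i)
    idx4[0][i] = idx2[i][0];

  // ... then vector.shape_cast to vector<8xindex> (%[[IDX5]]) and
  // vector.extractelement at 0 (%[[IDX6]]): the base row of the read.
  std::cout << idx4[0][0] << '\n'; // prints 0
}
```

Because the indices along dim 0 form the contiguous sequence 0..7, a single scalar base index plus an in-bounds `vector.transfer_read` of `vector<8x1xf32>` covers all eight `tensor.extract`s, which is what lets the vectorizer emit a contiguous read instead of a gather.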
