Commit b88f8e3

[Linalg][Vectorization] Add support for linalg vectorization case with outer non unit dim
1 parent 3b3accb

2 files changed: 74 additions, 4 deletions

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 22 additions & 4 deletions
@@ -810,6 +810,25 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
 
 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
 
+
+/// Find the non constant dim in a linalgOp. This is used for finding contiguous
+/// loads and it is expected that only one dim will be non constant, if thats
+/// not the case this function will assert.
+static uint64_t getNonUnitLoopDim(LinalgOp linalgOp) {
+  uint64_t nonUnitDim = 0;
+  uint64_t countNonUnitDim = 0;
+  for (auto tripCount : llvm::enumerate(linalgOp.getStaticLoopRanges())) {
+    if (tripCount.value() != 1) {
+      nonUnitDim = tripCount.index();
+      countNonUnitDim++;
+    }
+  }
+  assert(countNonUnitDim == 1 &&
+         "Expected only one non unit loop dim in this linalg op");
+  return nonUnitDim;
+}
+
+
 /// Checks whether `val` can be used for calculating a loop invariant index.
 static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
                                VectorType resType) {
@@ -889,11 +908,10 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
   Operation *defOp = val.getDefiningOp();
   assert(defOp && "This is neither a block argument nor an operation result");
 
-  // Given the assumption on the loop ranges above, only the trailing loop
-  // index is not constant.
-  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
+  auto nonUnitLoopDim = getNonUnitLoopDim(linalgOp);
+
   if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
-    foundIndexOp = (indexOp.getDim() == trailingLoopDim);
+    foundIndexOp = (indexOp.getDim() == nonUnitLoopDim);
     return true;
   }
 
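For context, the helper added above replaces the old assumption that only the trailing loop index is non-constant: it scans the static loop ranges and returns the single dimension whose trip count is not 1. Below is a minimal standalone C++ sketch of that scan, assuming a plain std::vector<int64_t> stands in for the result of linalgOp.getStaticLoopRanges(); the function name and driver are illustrative, not part of the patch.

#include <cassert>
#include <cstdint>
#include <vector>

// Sketch of the scan in getNonUnitLoopDim: return the index of the single
// loop dimension whose static trip count is not 1.
static uint64_t findNonUnitLoopDim(const std::vector<int64_t> &loopRanges) {
  uint64_t nonUnitDim = 0;
  uint64_t countNonUnitDim = 0;
  for (uint64_t idx = 0; idx < loopRanges.size(); ++idx) {
    if (loopRanges[idx] != 1) {
      nonUnitDim = idx;   // remember the last non-unit dim seen
      ++countNonUnitDim;  // the pass expects exactly one such dim
    }
  }
  assert(countNonUnitDim == 1 &&
         "Expected only one non unit loop dim in this linalg op");
  return nonUnitDim;
}

int main() {
  // The new test below has loop ranges [8, 1]: the non-unit dim is the
  // outer dim 0, which the old trailing-dim computation would have missed.
  assert(findNonUnitLoopDim({8, 1}) == 0);
  // A trailing non-unit dim, e.g. ranges [1, 4], still resolves as before.
  assert(findNonUnitLoopDim({1, 4}) == 1);
  return 0;
}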

mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir

Lines changed: 52 additions & 0 deletions
@@ -253,6 +253,58 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>
+func.func @vectorize_nd_tensor_extract_without_outer_unit_dim(%arg0: tensor<8x128x768xf32>, %arg1 : index) -> tensor<8x1xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = tensor.empty() : tensor<8x1xf32>
+  %1 = linalg.generic {
+    indexing_maps = [#map],
+    iterator_types = ["parallel", "parallel"]
+  } outs(%0 : tensor<8x1xf32>) {
+  ^bb0(%arg5: f32):
+    %2 = linalg.index 0 : index
+    %3 = linalg.index 1 : index
+    %4 = affine.apply #map1(%arg1, %3, %arg1)
+    %extracted = tensor.extract %arg0[%2, %c0, %4] : tensor<8x128x768xf32>
+    linalg.yield %extracted : f32
+  } -> tensor<8x1xf32>
+  return %1 : tensor<8x1xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg2: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg2 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {vectorize_nd_extract} : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_without_outer_unit_dim
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<8x128x768xf32>
+// CHECK-SAME:  %[[ARG1:.*]]: index
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[CST:.*]] = arith.constant dense<768> : vector<1x8xindex>
+// CHECK-DAG:   %[[CST_0:.*]] = arith.constant dense<128> : vector<1x8xindex>
+// CHECK-DAG:   %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
+// CHECK-DAG:   %[[CST_2:.*]] = arith.constant dense<true> : vector<8x1xi1>
+// CHECK-DAG:   %[[CST_3:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+// CHECK:       %[[EMPTY:.*]] = tensor.empty() : tensor<8x1xf32>
+// CHECK:       %[[B1:.*]] = vector.broadcast %[[CST_3]] : vector<8xindex> to vector<1x8xindex>
+// CHECK:       %[[ADDI_ARG1:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index
+// CHECK:       %[[B2:.*]] = vector.broadcast %[[ADDI_ARG1]] : index to vector<1xindex>
+// CHECK:       %[[MULI_1:.*]] = arith.muli %[[B1]], %[[CST_0]] : vector<1x8xindex>
+// CHECK:       %[[MULI_2:.*]] = arith.muli %[[MULI_1]], %[[CST]] : vector<1x8xindex>
+// CHECK:       %[[T:.*]] = vector.transpose %[[MULI_2]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
+// CHECK:       %[[B3:.*]] = vector.broadcast %[[B2]] : vector<1xindex> to vector<8x1xindex>
+// CHECK:       %[[ADDI:.*]] = arith.addi %[[B3]], %[[T]] : vector<8x1xindex>
+// CHECK:       %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_2]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
+// CHECK:       vector.transfer_write %[[GATHER]], %[[EMPTY]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
+
 // -----
 
 #map = affine_map<(d0) -> (d0)>
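As a sanity check on the CHECK lines above: with loop ranges [8, 1], lane d0 reads element [d0, 0, %arg1 + d1 + %arg1] with d1 fixed at 0 (unit dim), so each gather offset linearizes to d0 * 128 * 768 + 2 * %arg1. Below is a small standalone C++ sketch of that arithmetic, mirroring the MULI_1, MULI_2 and ADDI steps in plain host code; the example value for %arg1 is an assumption.

#include <array>
#include <cstdint>
#include <iostream>

int main() {
  // Shapes from the test: tensor<8x128x768xf32>, one gather lane per d0.
  const int64_t dim1 = 128, dim2 = 768;
  const int64_t arg1 = 5; // illustrative value for %arg1

  std::array<int64_t, 8> offsets; // vector<8x1xindex>, flattened
  for (int64_t d0 = 0; d0 < 8; ++d0)
    // Lane index scaled by 128 then 768 (MULI_1, MULI_2), plus
    // %arg1 + %arg1 (ADDI_ARG1, broadcast into ADDI).
    offsets[d0] = d0 * dim1 * dim2 + (arg1 + arg1);

  // Each offset addresses element [d0, 0, 2 * arg1] of the row-major
  // 8x128x768 tensor, matching the vector.gather indices.
  for (int64_t off : offsets)
    std::cout << off << "\n";
  return 0;
}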
