Skip to content

Commit 46bd65a

Browse files
authored
[mlir][LinAlg] Vectorize reverse-like ops using vector.gather ops. (#83205)
The reverse op is treated as a VectorMemoryAccessKind::Contiguous load. It is contiguous slice, but we'll need to compute indices differently and apply a reverse at vector level. It takes non-trivial efforts for the approach. The revision flips the case to use vector.gather. Otherwise there are functionality issues. E.g., the below example loaded `2, 3, 4` (which is a bug), but what we want is `2, 1, 0`. Before vectorization: ```mlir func.func @vectorize_reverse_like_tensor_extract(%arg0: tensor<1x2x3xf32>, %arg1: tensor<1x1x3xf32>, %arg2: index) -> tensor<1x1x3xf32> { %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index %c2 = arith.constant 2 : index %0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel"]} outs(%arg1 : tensor<1x1x3xf32>) { ^bb0(%out: f32): %1 = linalg.index 1 : index %2 = linalg.index 0 : index %3 = affine.apply #map1(%1, %2, %arg2) %4 = linalg.index 2 : index %5 = arith.subi %c2, %4 : index %extracted = tensor.extract %arg0[%c0, %3, %5] : tensor<1x2x3xf32> linalg.yield %extracted : f32 } -> tensor<1x1x3xf32> return %0 : tensor<1x1x3xf32> } ``` Partial IR after vectorization: ``` %5 = vector.constant_mask [1, 1, 3] : vector<1x1x4xi1> %6 = vector.broadcast %arg0 : index to vector<1x1x4xindex> %7 = vector.shape_cast %6 : vector<1x1x4xindex> to vector<4xindex> %8 = vector.extractelement %7[%c0_i32 : i32] : vector<4xindex> %9 = vector.transfer_read %3[%c0, %8, %c2], %cst, %5 {in_bounds = [true, true, true]} : tensor<1x2x3xf32>, vector<1x1x4xf32> ```
1 parent 9da9b5f commit 46bd65a

File tree

2 files changed

+46
-2
lines changed

2 files changed

+46
-2
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -891,8 +891,7 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
891891

892892
// Conservatively reject Ops that could lead to indices with stride other
893893
// than 1.
894-
if (!isa<arith::AddIOp, arith::SubIOp, arith::ConstantOp, linalg::IndexOp>(
895-
ancestor))
894+
if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
896895
return false;
897896

898897
bool result = false;

mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,3 +550,48 @@ module attributes {transform.with_named_sequence} {
550550
transform.yield
551551
}
552552
}
553+
554+
// -----
555+
556+
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
557+
#map1 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>
558+
func.func @vectorize_reverse_like_tensor_extract(%arg0: tensor<1x2x3xf32>, %arg1: tensor<1x1x3xf32>, %arg2: index) -> tensor<1x1x3xf32> {
559+
%c1 = arith.constant 1 : index
560+
%c0 = arith.constant 0 : index
561+
%c2 = arith.constant 2 : index
562+
%0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel"]} outs(%arg1 : tensor<1x1x3xf32>) {
563+
^bb0(%out: f32):
564+
%1 = linalg.index 1 : index
565+
%2 = linalg.index 0 : index
566+
%3 = affine.apply #map1(%1, %2, %arg2)
567+
%4 = linalg.index 2 : index
568+
%5 = arith.subi %c2, %4 : index
569+
%extracted = tensor.extract %arg0[%c0, %3, %5] : tensor<1x2x3xf32>
570+
linalg.yield %extracted : f32
571+
} -> tensor<1x1x3xf32>
572+
return %0 : tensor<1x1x3xf32>
573+
}
574+
// CHECK-LABEL: func.func @vectorize_reverse_like_tensor_extract
575+
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]
576+
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]
577+
// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]
578+
// CHECK-DAG: %[[CST:.+]] = arith.constant dense<3> : vector<1x1x3xindex>
579+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
580+
// CHECK-DAG: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x3xi1>
581+
// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32>
582+
// CHECK-DAG: %[[INIT_IDX:.+]] = arith.constant dense<[2, 1, 0]> : vector<3xindex>
583+
// CHECK: %[[T0:.+]] = vector.broadcast %[[ARG2]] : index to vector<1x1x3xindex>
584+
// CHECK: %[[T1:.+]] = arith.muli %[[T0]], %[[CST]] : vector<1x1x3xindex>
585+
// CHECK: %[[T2:.+]] = vector.broadcast %[[INIT_IDX]]
586+
// CHECK: %[[T3:.+]] = arith.addi %[[T2]], %[[T1]]
587+
// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[T3]]], %[[MASK]], %[[PASSTHRU]]
588+
// CHECK: vector.transfer_write %[[GATHER]]
589+
590+
module attributes {transform.with_named_sequence} {
591+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
592+
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
593+
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
594+
%2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
595+
transform.yield
596+
}
597+
}

0 commit comments

Comments (0)