
Commit 20b886a

[mlir][linalg] Relax tensor.extract vectorization
Simplifies the vectorization of tensor.extract so that:

* all cases that read into a genuinely multi-dim vector (*) are considered a gather load,
* all other cases are considered as potential contiguous loads.

This change means that the following extraction from a "column" tensor is correctly identified as a scalar load followed by a broadcast (rather than a gather load).

```mlir
func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
  %c4 = arith.constant 4 : index
  %c0 = arith.constant 0 : index
  %cst = arith.constant dense<[...]> : tensor<15x1xi32>

  %out = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
    iterator_types = ["parallel", "parallel", "parallel"]}
    outs(%in : tensor<1x1x4xi32>) {
  ^bb0(%out: i32):
    %8 = linalg.index 0 : index
    %idx_0 = linalg.index 0 : index
    %extracted = tensor.extract %cst[%idx_0, %c0] : tensor<15x1xi32>
    linalg.yield %extracted : i32
  } -> tensor<1x1x4xi32>

  return %out : tensor<1x1x4xi32>
}
```

(*) `vector<1x4x1xf32>` is considered as a 1D vector in this context.
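For illustration, the "genuinely multi-dim" test in the patch boils down to counting how many dimensions of the target vector shape are greater than 1 (the patch does this with `llvm::count_if`). Below is a minimal standalone sketch of just that shape check, assuming plain C++ with `std::count_if`; the helper name `isEffectively1D` is made up for this example and is not part of the patch.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Sketch of the shape test only: a vector shape is treated as effectively 1D
// when exactly one dimension is greater than 1. Anything else is classified
// as a gather load by the patched logic.
static bool isEffectively1D(const std::vector<int64_t> &shape) {
  return std::count_if(shape.begin(), shape.end(),
                       [](int64_t dim) { return dim > 1; }) == 1;
}

int main() {
  assert(isEffectively1D({1, 4, 1}));   // vector<1x4x1xf32>: treated as 1D
  assert(isEffectively1D({1, 1, 4}));   // vector<1x1x4xi32>: treated as 1D
  assert(!isEffectively1D({2, 1, 4}));  // genuinely multi-dim -> gather load
  return 0;
}
```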
1 parent 2ba3fe7 commit 20b886a

File tree

2 files changed: +71 −20 lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 15 additions & 20 deletions
```diff
@@ -944,27 +944,22 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
   if (linalgOp.hasDynamicShape())
     return VectorMemoryAccessKind::Gather;
 
-  // 1. Assume that it's a gather load when reading _into_:
-  //    * an n-D "vector", like `tensor<1x2x4xi32` or `tensor<2x1x4xi32>`, or
-  //    * a 1-D "vector" with the trailing dim equal 1, e.g. `tensor<1x4x1xi32`.
-  // TODO: Relax these conditions.
-  // FIXME: This condition assumes non-dynamic sizes.
-  if ((llvm::count_if(targetShape,
-                      [](int64_t dimSize) { return dimSize > 1; }) != 1) ||
-      targetShape.back() == 1)
-    return VectorMemoryAccessKind::Gather;
-
-  // 2. Assume that it's a gather load when reading _from_ a tensor for which
-  //    the trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
-  // TODO: Relax this condition.
-  if (inputShape.getShape().back() == 1)
+  // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
+  // otherwise.
+  bool isOutput1DVector = (llvm::count_if(targetShape, [](int64_t dimSize) {
+                             return dimSize > 1;
+                           }) == 1);
+
+  // 1. Assume that it's a gather load when reading non-1D vector.
+  if (!isOutput1DVector)
     return VectorMemoryAccessKind::Gather;
 
   bool leadingIdxsLoopInvariant = true;
 
-  // 3. Analyze the leading indices of `extractOp`.
+  // 2. Analyze the leading indices of `extractOp`.
   // Look at the way each index is calculated and decide whether it is suitable
-  // for a contiguous load, i.e. whether it's loop invariant.
+  // for a contiguous load, i.e. whether it's loop invariant. If not, it's a
+  // gather load.
   auto indices = extractOp.getIndices();
   auto leadIndices = indices.drop_back(1);
 
@@ -980,13 +975,13 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
     return VectorMemoryAccessKind::Gather;
   }
 
-  // 4. Analyze the trailing index for `extractOp`.
+  // 3. Analyze the trailing index for `extractOp`.
   // At this point we know that the leading indices are loop invariant. This
   // means that is potentially a scalar or a contiguous load. We can decide
   // based on the trailing idx.
   auto extractOpTrailingIdx = indices.back();
 
-  // 4a. Scalar broadcast load
+  // 3a. Scalar broadcast load
   // If the trailing index is loop invariant then this is a scalar load.
   if (leadingIdxsLoopInvariant &&
       isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) {
@@ -995,7 +990,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
     return VectorMemoryAccessKind::ScalarBroadcast;
   }
 
-  // 4b. Contiguous loads
+  // 3b. Contiguous loads
   // The trailing `extractOp` index should increment with every loop iteration.
   // This effectively means that it must be based on the trailing loop index.
   // This is what the following bool captures.
@@ -1009,7 +1004,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
     return VectorMemoryAccessKind::Contiguous;
   }
 
-  // 5. Fallback case - gather load.
+  // 4. Fallback case - gather load.
   LDBG("Found gather load: " << extractOp);
   return VectorMemoryAccessKind::Gather;
 }
```

mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir

Lines changed: 56 additions & 0 deletions
```diff
@@ -595,3 +595,59 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+
+// -----
+
+func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
+  %c4 = arith.constant 4 : index
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<[[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
+
+  %out = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} outs(%in : tensor<1x1x4xi32>) {
+  ^bb0(%out: i32):
+    %8 = linalg.index 0 : index
+    %idx_0 = linalg.index 0 : index
+    %extracted = tensor.extract %cst[%idx_0, %c0] : tensor<15x1xi32>
+    linalg.yield %extracted : i32
+  } -> tensor<1x1x4xi32>
+
+  return %out:tensor<1x1x4xi32>
+}
+
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> (0, 0, 0)>
+// CHECK-LABEL: func.func @vectorize_scalar_broadcast_column_tensor(
+// CHECK-SAME:    %[[VAL_0:.*]]: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
+// CHECK:         %[[VAL_1:.*]] = arith.constant 4 : index
+// CHECK:         %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:         %[[VAL_3:.*]] = arith.constant dense<{{\[\[}}0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
+// CHECK:         %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK:         %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK:         %[[VAL_6:.*]] = arith.constant 4 : index
+// CHECK:         %[[VAL_7:.*]] = arith.constant 0 : index
+// CHECK:         %[[VAL_8:.*]] = arith.constant 0 : i32
+// CHECK:         %[[VAL_9:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_7]], %[[VAL_7]]], %[[VAL_8]] : tensor<1x1x4xi32>, vector<1x1x4xi32>
+// CHECK:         %[[VAL_10:.*]] = vector.step : vector<1xindex>
+// CHECK:         %[[VAL_11:.*]] = vector.broadcast %[[VAL_10]] : vector<1xindex> to vector<4x1x1xindex>
+// CHECK:         %[[VAL_12:.*]] = vector.transpose %[[VAL_11]], [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
+// CHECK:         %[[VAL_13:.*]] = vector.step : vector<1xindex>
+// CHECK:         %[[VAL_14:.*]] = vector.broadcast %[[VAL_13]] : vector<1xindex> to vector<4x1x1xindex>
+// CHECK:         %[[VAL_15:.*]] = vector.transpose %[[VAL_14]], [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
+// CHECK:         %[[VAL_16:.*]] = arith.constant dense<true> : vector<1x1x4xi1>
+// CHECK:         %[[VAL_17:.*]] = arith.constant dense<0> : vector<1x1x4xi32>
+// CHECK:         %[[VAL_18:.*]] = arith.constant 0 : index
+// CHECK:         %[[VAL_19:.*]] = arith.constant 0 : i32
+// CHECK:         %[[VAL_20:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex>
+// CHECK:         %[[VAL_21:.*]] = vector.extractelement %[[VAL_20]]{{\[}}%[[VAL_19]] : i32] : vector<4xindex>
+// CHECK:         %[[VAL_22:.*]] = arith.constant 0 : i32
+// CHECK:         %[[VAL_23:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_21]], %[[VAL_2]]], %[[VAL_22]] {in_bounds = [true, true, true], permutation_map = #[[$ATTR_1]]} : tensor<15x1xi32>, vector<1x1x4xi32>
+// CHECK:         %[[VAL_24:.*]] = arith.constant 0 : index
+// CHECK:         %[[VAL_25:.*]] = vector.transfer_write %[[VAL_23]], %[[VAL_0]]{{\[}}%[[VAL_24]], %[[VAL_24]], %[[VAL_24]]] : vector<1x1x4xi32>, tensor<1x1x4xi32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [1, 1, 4] { vectorize_nd_extract } : !transform.any_op
+    transform.yield
+  }
+}
```
