
Commit ef8985a

[mlir][linalg] Upgrade vectorisation of tensor.extract
This PR removes the assumption that reading from a dynamic tensor is always a gather load:

```mlir
%extracted = tensor.extract %src[%c79, %3] : tensor<?x?xf32>
```

That assumption was originally introduced to simplify the implementation and to reduce the number of cases to consider. Now that the vectorisation of `tensor.extract` has been around for > 1 year and has been quite stable, we can safely relax it.

This is a relatively small change - rather than using the parent linalg Op to infer the target output shape (not possible with dynamic shapes), the vectorizer will use the (previously constructed) output vector shape instead.

As expected, the following test required updating (`vector.gather` -> `vector.transfer_read`):
* @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous

A similar test for scalable vectors is also added.
1 parent 2ba3fe7 commit ef8985a
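For a concrete picture of the change, here is an illustrative before/after sketch of the contiguous case above, loosely based on the test updated in this commit (names such as `%mask`, `%pad`, `%offsets` and `%idx_start` are placeholders, not verbatim vectorizer output):

```mlir
// The extract being vectorised; %3 grows linearly with the trailing loop index.
%extracted = tensor.extract %src[%c79, %3] : tensor<?x?xf32>

// Before this patch: dynamic source => unconditionally a gather (safe but slow).
%old = vector.mask %mask { vector.gather %src[%c0, %c0] [%offsets], %all_true, %pad
  : tensor<?x?xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32>
} : vector<1x4xi1> -> vector<1x4xf32>

// After this patch: recognised as contiguous and lowered to a masked read
// starting at the first index.
%new = vector.mask %mask { vector.transfer_read %src[%c79, %idx_start], %pad
  {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32>
} : vector<1x4xi1> -> vector<1x4xf32>
```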

3 files changed: +126 -74 lines changed


mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 28 additions & 34 deletions
@@ -808,14 +808,14 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
 
 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
 
-/// Checks whether /p val can be used for calculating a loop invariant index.
-static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
+/// Checks whether `val` can be used for calculating a loop invariant index.
+static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
+                               VectorType resType) {
 
-  auto targetShape = linalgOp.getStaticLoopRanges();
-  assert(((llvm::count_if(targetShape,
+  assert(((llvm::count_if(resType.getShape(),
                           [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
          "n-D vectors are not yet supported");
-  assert(targetShape.back() != 1 &&
+  assert(resType.getShape().back() != 1 &&
          "1-D vectors with the trailing dim eqaual 1 are not yet supported");
 
   // Blocks outside _this_ linalg.generic are effectively loop invariant.
@@ -849,7 +849,7 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
 
   bool result = true;
   for (auto op : ancestor->getOperands())
-    result &= isLoopInvariantIdx(linalgOp, op);
+    result &= isLoopInvariantIdx(linalgOp, op, resType);
 
   return result;
 }
@@ -871,13 +871,12 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
 /// where <dim> is the trailing dim of the iteration space. \p foundIndexOp is
 /// updated to `true` when such an op is found.
 static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
-                                bool &foundIndexOp) {
+                                bool &foundIndexOp, VectorType resType) {
 
-  auto targetShape = linalgOp.getStaticLoopRanges();
-  assert(((llvm::count_if(targetShape,
+  assert(((llvm::count_if(resType.getShape(),
                           [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
          "n-D vectors are not yet supported");
-  assert(targetShape.back() != 1 &&
+  assert(resType.getShape().back() != 1 &&
          "1-D vectors with the trailing dim 1 are not yet supported");
 
   // Blocks outside _this_ linalg.generic are effectively loop invariant.
@@ -912,46 +911,41 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
 
   bool result = false;
   for (auto op : ancestor->getOperands())
-    result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp);
+    result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);
 
   return result;
 }
 
 /// Infer the memory access pattern for the input ExtractOp
 ///
-/// Based on the operation shapes and indices (usually based on the iteration
-/// space of the parent `linalgOp` operation), decides whether the input
-/// ExtractOp is a contiguous load (including a broadcast of a scalar) or a
-/// gather load.
+/// Based on the ExtratOp result shape and the access indices, decides whether
+/// this Op corresponds to a contiguous load (including a broadcast of a scalar)
+/// or a gather load. When analysing the ExtractOp indices (to identify
+/// contiguous laods), this method looks for "loop" invariant indices (e.g.
+/// block arguments) and indices that change linearly (e.g. via `linalg.index`
+/// Op).
 ///
 /// Note that it is always safe to use gather load operations for contiguous
 /// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
 /// that `extractOp` is a gather load.
 static VectorMemoryAccessKind
 getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
-                                    LinalgOp &linalgOp) {
+                                    LinalgOp &linalgOp, VectorType resType) {
 
-  auto targetShape = linalgOp.getStaticLoopRanges();
   auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
 
-  // 0.1 Is this a 0-D vector? If yes then this is a scalar broadcast.
+  // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
   if (inputShape.getShape().empty())
     return VectorMemoryAccessKind::ScalarBroadcast;
 
-  // 0.2 In the case of dynamic shapes just bail-out and assume that it's a
-  // gather load.
-  // TODO: Relax this condition.
-  if (linalgOp.hasDynamicShape())
-    return VectorMemoryAccessKind::Gather;
-
   // 1. Assume that it's a gather load when reading _into_:
-  //    * an n-D "vector", like `tensor<1x2x4xi32` or `tensor<2x1x4xi32>`, or
-  //    * a 1-D "vector" with the trailing dim equal 1, e.g. `tensor<1x4x1xi32`.
+  //    * an n-D "vector", like `vector<1x2x4xi32` or `vector<2x1x4xi32>`, or
+  //    * a 1-D "vector" with the trailing dim equal 1, e.g.
+  //    `vector<1x4x1xi32>`.
   // TODO: Relax these conditions.
-  // FIXME: This condition assumes non-dynamic sizes.
-  if ((llvm::count_if(targetShape,
+  if ((llvm::count_if(resType.getShape(),
                       [](int64_t dimSize) { return dimSize > 1; }) != 1) ||
-      targetShape.back() == 1)
+      resType.getShape().back() == 1)
    return VectorMemoryAccessKind::Gather;
 
  // 2. Assume that it's a gather load when reading _from_ a tensor for which
@@ -972,7 +966,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
     if (inputShape.getShape()[i] == 1)
       continue;
 
-    leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal);
+    leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
   }
 
   if (!leadingIdxsLoopInvariant) {
@@ -989,7 +983,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
   // 4a. Scalar broadcast load
   // If the trailing index is loop invariant then this is a scalar load.
   if (leadingIdxsLoopInvariant &&
-      isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) {
+      isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
     LDBG("Found scalar broadcast load: " << extractOp);
 
     return VectorMemoryAccessKind::ScalarBroadcast;
@@ -1000,8 +994,8 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
   // This effectively means that it must be based on the trailing loop index.
   // This is what the following bool captures.
   bool foundIndexOp = false;
-  bool isContiguousLoad =
-      isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp);
+  bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
+                                              foundIndexOp, resType);
   isContiguousLoad &= foundIndexOp;
 
   if (isContiguousLoad) {
@@ -1042,7 +1036,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
       rewriter.create<arith::ConstantIndexOp>(loc, 0));
 
   VectorMemoryAccessKind memAccessKind =
-      getTensorExtractMemoryAccessPattern(extractOp, linalgOp);
+      getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);
 
   // 1. Handle gather access
   if (memAccessKind == VectorMemoryAccessKind::Gather) {
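As a reading aid for the classification in `getTensorExtractMemoryAccessPattern`, the three `VectorMemoryAccessKind` cases can be illustrated with hypothetical extracts inside a `linalg.generic` body (a sketch, assuming the trailing dim is the one being vectorised; `%inv`, `%src` and `%perm` are made-up values, with `%inv` loop invariant):

```mlir
// ScalarBroadcast: all indices are loop invariant - a single element is
// loaded and broadcast to every lane.
%e0 = tensor.extract %src[%inv, %inv] : tensor<?x?xf32>

// Contiguous: the trailing index advances in lock-step with the trailing loop
// dim (via linalg.index), so consecutive lanes read consecutive elements.
%i = linalg.index 1 : index
%e1 = tensor.extract %src[%inv, %i] : tensor<?x?xf32>

// Gather: the trailing index is an arbitrary function of the iteration space
// (here itself loaded from a tensor), so lanes may touch scattered addresses.
%j = tensor.extract %perm[%i] : tensor<?xindex>
%e2 = tensor.extract %src[%inv, %j] : tensor<?x?xf32>
```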

mlir/test/Dialect/Linalg/vectorization-scalable.mlir

Lines changed: 3 additions & 6 deletions
@@ -162,17 +162,14 @@ func.func @vectorize_linalg_index(%arg0: tensor<3x3x?xf32>, %arg1: tensor<1x1x?x
 
 // CHECK-LABEL: @vectorize_linalg_index
 // CHECK-SAME: %[[SRC:.*]]: tensor<3x3x?xf32>, %[[DST:.*]]: tensor<1x1x?xf32>
-// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x[4]xf32>
-// CHECK-DAG: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x[4]xi1>
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
 // CHECK: %[[DST_DIM2:.*]] = tensor.dim %[[DST]], %[[C2]] : tensor<1x1x?xf32>
-// CHECK: %[[DST_MASK:.*]] = vector.create_mask %[[C1]], %[[C1]], %[[DST_DIM2]] : vector<1x1x[4]xi1>
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[C1]], %[[C1]], %[[DST_DIM2]] : vector<1x1x[4]xi1>
 // CHECK: %[[INDEX_VEC:.*]] = vector.step : vector<[4]xindex>
-// CHECK: %[[INDEX_VEC_BCAST:.*]] = vector.broadcast %[[INDEX_VEC]] : vector<[4]xindex> to vector<1x1x[4]xindex>
-// CHECK: %[[GATHER:.*]] = vector.mask %[[DST_MASK]] { vector.gather %[[SRC]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {{\[}}%[[INDEX_VEC_BCAST]]], %[[MASK]], %[[PASSTHRU]] : tensor<3x3x?xf32>, vector<1x1x[4]xindex>, vector<1x1x[4]xi1>, vector<1x1x[4]xf32> into vector<1x1x[4]xf32> } : vector<1x1x[4]xi1> -> vector<1x1x[4]xf32>
-// CHECK: %[[OUT:.*]] = vector.mask %[[DST_MASK]] { vector.transfer_write %[[GATHER]], %[[DST]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x[4]xf32>, tensor<1x1x?xf32> } : vector<1x1x[4]xi1> -> tensor<1x1x?xf32>
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]][%c0, %c0, %2], %cst {in_bounds = [true, true, true]} : tensor<3x3x?xf32>, vector<1x1x[4]xf32> } : vector<1x1x[4]xi1> -> vector<1x1x[4]xf32>
+// CHECK: %[[OUT:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[DST]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x[4]xf32>, tensor<1x1x?xf32> } : vector<1x1x[4]xi1> -> tensor<1x1x?xf32>
 // CHECK: return %[[OUT]] : tensor<1x1x?xf32>
 
 module attributes {transform.with_named_sequence} {

mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir

Lines changed: 95 additions & 34 deletions
@@ -120,52 +120,54 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(%6: tensor<?x?xf32>, %arg0: index, %extracted_slice : tensor<?x?xf32>) -> tensor<?x?xf32> {
+func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(
+    %src: tensor<?x?xf32>,
+    %output : tensor<?x?xf32>,
+    %idx: index) -> tensor<?x?xf32> {
+
   %c79 = arith.constant 79 : index
   %1 = linalg.generic {
     indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
     iterator_types = ["parallel", "parallel"]
-  } outs(%extracted_slice : tensor<?x?xf32>) {
+  } outs(%output : tensor<?x?xf32>) {
   ^bb0(%out: f32):
     %2 = linalg.index 1 : index
-    %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %arg0)
-    %extracted = tensor.extract %6[%c79, %3] : tensor<?x?xf32>
+    %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %idx)
+    %extracted = tensor.extract %src[%c79, %3] : tensor<?x?xf32>
     linalg.yield %extracted : f32
   } -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
 
 // CHECK-LABEL: func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: index,
-// CHECK-SAME: %[[VAL_2:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 79 : index
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_2]], %[[VAL_4]] : tensor<?x?xf32>
-// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_2]], %[[VAL_6]] : tensor<?x?xf32>
-// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_10:.*]] = vector.create_mask %[[VAL_5]], %[[VAL_7]] : vector<1x4xi1>
-// CHECK: %[[VAL_11:.*]] = vector.mask %[[VAL_10]] { vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_8]], %[[VAL_8]]], %[[VAL_9]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
-// CHECK: %[[VAL_12:.*]] = vector.step : vector<4xindex>
-// CHECK: %[[VAL_13:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
-// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : vector<4xindex>
-// CHECK-DAG: %[[VAL_15:.*]] = arith.constant dense<true> : vector<1x4xi1>
-// CHECK-DAG: %[[VAL_16:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32>
-// CHECK-DAG: %[[VAL_17:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_18:.*]] = arith.constant dense<79> : vector<1x4xindex>
-// CHECK-DAG: %[[VAL_19:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_20:.*]] = tensor.dim %[[VAL_0]], %[[VAL_19]] : tensor<?x?xf32>
-// CHECK: %[[VAL_21:.*]] = vector.broadcast %[[VAL_20]] : index to vector<1x4xindex>
-// CHECK: %[[VAL_22:.*]] = arith.muli %[[VAL_18]], %[[VAL_21]] : vector<1x4xindex>
-// CHECK: %[[VAL_23:.*]] = vector.broadcast %[[VAL_14]] : vector<4xindex> to vector<1x4xindex>
-// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_23]], %[[VAL_22]] : vector<1x4xindex>
-// CHECK: %[[VAL_25:.*]] = vector.mask %[[VAL_10]] { vector.gather %[[VAL_0]]{{\[}}%[[VAL_17]], %[[VAL_17]]] {{\[}}%[[VAL_24]]], %[[VAL_15]], %[[VAL_16]] : tensor<?x?xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
-// CHECK: %[[VAL_26:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_27:.*]] = vector.mask %[[VAL_10]] { vector.transfer_write %[[VAL_25]], %[[VAL_2]]{{\[}}%[[VAL_26]], %[[VAL_26]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<?x?xf32> } : vector<1x4xi1> -> tensor<?x?xf32>
-// CHECK: return %[[VAL_27]] : tensor<?x?xf32>
-// CHECK: }
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index)
+
+/// Create the mask
+// CHECK: %[[C79:.*]] = arith.constant 79 : index
+// CHECK: %[[DIM_0_IDX:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_0_IDX]] : tensor<?x?xf32>
+// CHECK: %[[DIM_1_IDX:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_1_IDX]] : tensor<?x?xf32>
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<1x4xi1>
+
+/// TODO: This transfer_read is redundant - remove
+// CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
+
+/// Caluclate the index vector
+// CHECK: %[[STEP:.*]] = vector.step : vector<4xindex>
+// CHECK: %[[IDX_BC:.*]] = vector.broadcast %[[IDX]] : index to vector<4xindex>
+// CHECK: %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<4xindex>
+// CHECK: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<4xindex> to vector<4xindex>
+
+/// Extract the starting point from the index vector
+// CHECK: %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<4xindex>
+
+// Final read and write
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
+// CHECK: %[[VAL_24:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : vector<1x4xf32>, tensor<?x?xf32> } : vector<1x4xi1> -> tensor<?x?xf32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -177,6 +179,65 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous_scalable(
+    %src: tensor<?x?xf32>,
+    %output : tensor<?x?xf32>,
+    %idx: index) -> tensor<?x?xf32> {
+
+  %c79 = arith.constant 79 : index
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } outs(%output : tensor<?x?xf32>) {
+  ^bb0(%out: f32):
+    %2 = linalg.index 1 : index
+    %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %idx)
+    %extracted = tensor.extract %src[%c79, %3] : tensor<?x?xf32>
+    linalg.yield %extracted : f32
+  } -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous_scalable(
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index)
+
+/// Create the mask
+// CHECK: %[[C79:.*]] = arith.constant 79 : index
+// CHECK: %[[DIM_0_IDX:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_0_IDX]] : tensor<?x?xf32>
+// CHECK: %[[DIM_1_IDX:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_1_IDX]] : tensor<?x?xf32>
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<1x[4]xi1>
+
+/// TODO: This transfer_read is redundant - remove
+// CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>
+
+/// Caluclate the index vector
+// CHECK: %[[STEP:.*]] = vector.step : vector<[4]xindex>
+// CHECK: %[[IDX_BC:.*]] = vector.broadcast %[[IDX]] : index to vector<[4]xindex>
+// CHECK: %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<[4]xindex>
+// CHECK: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<[4]xindex> to vector<[4]xindex>
+
+/// Extract the starting point from the index vector
+// CHECK: %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<[4]xindex>
+
+// Final read and write
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>
+// CHECK: %[[VAL_24:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : vector<1x[4]xf32>, tensor<?x?xf32> } : vector<1x[4]xi1> -> tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [1, [4]] {vectorize_nd_extract} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @masked_vectorize_nd_tensor_extract_with_affine_apply_gather(%6: tensor<80x16xf32>, %arg0: index, %extracted_slice : tensor<1x3xf32>) -> tensor<1x3xf32> {
   %c16 = arith.constant 16 : index
   %1 = linalg.generic {