Skip to content

Commit 8e66303

Browse files
authored
[mlir][Vector] Remove trivial uses of vector.extractelement/vector.insertelement (1/N) (#116053)
This patch removes trivial usages of vector.extractelement/vector.insertelement. These operations can be fully represented by vector.extract/vector.insert. See https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops/71116 for more information. Further patches will remove more usages of these ops.
1 parent cd88bfc commit 8e66303

11 files changed

+102
-160
lines changed

mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -313,8 +313,7 @@ void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
313313
auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
314314
Value asF16s =
315315
rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
316-
Value result = rewriter.create<vector::ExtractElementOp>(
317-
loc, asF16s, rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0));
316+
Value result = rewriter.create<vector::ExtractOp>(loc, asF16s, 0);
318317
return rewriter.replaceOp(op, result);
319318
}
320319
VectorType outType = cast<VectorType>(op.getOut().getType());
@@ -334,13 +333,11 @@ void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
334333
for (int64_t i = 0; i < numElements; i += 2) {
335334
int64_t elemsThisOp = std::min(numElements, i + 2) - i;
336335
Value thisResult = nullptr;
337-
Value elemA = rewriter.create<vector::ExtractElementOp>(
338-
loc, in, rewriter.create<arith::ConstantIndexOp>(loc, i));
336+
Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i);
339337
Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
340338

341339
if (elemsThisOp == 2) {
342-
elemB = rewriter.create<vector::ExtractElementOp>(
343-
loc, in, rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + 1));
340+
elemB = rewriter.create<vector::ExtractOp>(loc, in, i + 1);
344341
}
345342

346343
thisResult =

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1134,8 +1134,6 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
11341134
// * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
11351135
// (0th) element and use that.
11361136
SmallVector<Value> transferReadIdxs;
1137-
auto zero = rewriter.create<arith::ConstantOp>(
1138-
loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
11391137
for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
11401138
Value idx = bvm.lookup(extractOp.getIndices()[i]);
11411139
if (idx.getType().isIndex()) {
@@ -1149,7 +1147,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
11491147
resultType.getScalableDims().back()),
11501148
idx);
11511149
transferReadIdxs.push_back(
1152-
rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
1150+
rewriter.create<vector::ExtractOp>(loc, indexAs1dVector, 0));
11531151
}
11541152

11551153
// `tensor.extract` is always in-bounds, hence the following holds.
@@ -1415,7 +1413,8 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
14151413
// 3.c. Not all ops support 0-d vectors, extract the scalar for now.
14161414
// TODO: remove this.
14171415
if (readType.getRank() == 0)
1418-
readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
1416+
readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
1417+
ArrayRef<int64_t>());
14191418

14201419
LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
14211420
<< "\n");
@@ -2273,7 +2272,8 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
22732272
loc, readType, copyOp.getSource(), indices,
22742273
rewriter.getMultiDimIdentityMap(srcType.getRank()));
22752274
if (cast<VectorType>(readValue.getType()).getRank() == 0) {
2276-
readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
2275+
readValue =
2276+
rewriter.create<vector::ExtractOp>(loc, readValue, ArrayRef<int64_t>());
22772277
readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
22782278
}
22792279
Operation *writeValue = rewriter.create<vector::TransferWriteOp>(

mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -391,9 +391,8 @@ struct TwoDimMultiReductionToReduction
391391
reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
392392
}
393393

394-
result = rewriter.create<vector::InsertElementOp>(
395-
loc, reductionOp->getResult(0), result,
396-
rewriter.create<arith::ConstantIndexOp>(loc, i));
394+
result = rewriter.create<vector::InsertOp>(loc, reductionOp->getResult(0),
395+
result, i);
397396
}
398397

399398
rewriter.replaceOp(rootOp, result);

mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp

Lines changed: 9 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -18,28 +18,6 @@
1818
using namespace mlir;
1919
using namespace mlir::vector;
2020

21-
// Helper that picks the proper sequence for inserting.
22-
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
23-
Value into, int64_t offset) {
24-
auto vectorType = cast<VectorType>(into.getType());
25-
if (vectorType.getRank() > 1)
26-
return rewriter.create<InsertOp>(loc, from, into, offset);
27-
return rewriter.create<vector::InsertElementOp>(
28-
loc, vectorType, from, into,
29-
rewriter.create<arith::ConstantIndexOp>(loc, offset));
30-
}
31-
32-
// Helper that picks the proper sequence for extracting.
33-
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
34-
int64_t offset) {
35-
auto vectorType = cast<VectorType>(vector.getType());
36-
if (vectorType.getRank() > 1)
37-
return rewriter.create<ExtractOp>(loc, vector, offset);
38-
return rewriter.create<vector::ExtractElementOp>(
39-
loc, vectorType.getElementType(), vector,
40-
rewriter.create<arith::ConstantIndexOp>(loc, offset));
41-
}
42-
4321
/// RewritePattern for InsertStridedSliceOp where source and destination vectors
4422
/// have different ranks.
4523
///
@@ -173,11 +151,13 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
173151
for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
174152
off += stride, ++idx) {
175153
// 1. extract the proper subvector (or element) from source
176-
Value extractedSource = extractOne(rewriter, loc, op.getSource(), idx);
154+
Value extractedSource =
155+
rewriter.create<ExtractOp>(loc, op.getSource(), idx);
177156
if (isa<VectorType>(extractedSource.getType())) {
178157
// 2. If we have a vector, extract the proper subvector from destination
179158
// Otherwise we are at the element level and no need to recurse.
180-
Value extractedDest = extractOne(rewriter, loc, op.getDest(), off);
159+
Value extractedDest =
160+
rewriter.create<ExtractOp>(loc, op.getDest(), off);
181161
// 3. Reduce the problem to lowering a new InsertStridedSlice op with
182162
// smaller rank.
183163
extractedSource = rewriter.create<InsertStridedSliceOp>(
@@ -186,7 +166,7 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
186166
getI64SubArray(op.getStrides(), /* dropFront=*/1));
187167
}
188168
// 4. Insert the extractedSource into the res vector.
189-
res = insertOne(rewriter, loc, extractedSource, res, off);
169+
res = rewriter.create<InsertOp>(loc, extractedSource, res, off);
190170
}
191171

192172
rewriter.replaceOp(op, res);
@@ -277,8 +257,8 @@ class Convert1DExtractStridedSliceIntoExtractInsertChain final
277257
};
278258

279259
/// RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
280-
/// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
281-
/// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
260+
/// For such cases, we can rewrite it to ExtractOp + lower rank
261+
/// ExtractStridedSliceOp + InsertOp for the n-D case.
282262
class DecomposeNDExtractStridedSlice
283263
: public OpRewritePattern<ExtractStridedSliceOp> {
284264
public:
@@ -317,12 +297,12 @@ class DecomposeNDExtractStridedSlice
317297
Value res = rewriter.create<SplatOp>(loc, dstType, zero);
318298
for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
319299
off += stride, ++idx) {
320-
Value one = extractOne(rewriter, loc, op.getVector(), off);
300+
Value one = rewriter.create<ExtractOp>(loc, op.getVector(), off);
321301
Value extracted = rewriter.create<ExtractStridedSliceOp>(
322302
loc, one, getI64SubArray(op.getOffsets(), /* dropFront=*/1),
323303
getI64SubArray(op.getSizes(), /* dropFront=*/1),
324304
getI64SubArray(op.getStrides(), /* dropFront=*/1));
325-
res = insertOne(rewriter, loc, extracted, res, idx);
305+
res = rewriter.create<InsertOp>(loc, extracted, res, idx);
326306
}
327307
rewriter.replaceOp(op, res);
328308
return success();

mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
func.func @scalar_trunc(%v: f32) -> f16{
66
// CHECK: %[[poison:.*]] = llvm.mlir.poison : f32
77
// CHECK: %[[trunc:.*]] = rocdl.cvt.pkrtz %[[value]], %[[poison]] : vector<2xf16>
8-
// CHECK: %[[extract:.*]] = vector.extractelement %[[trunc]][%c0 : index] : vector<2xf16>
8+
// CHECK: %[[extract:.*]] = vector.extract %[[trunc]][0] : f16 from vector<2xf16>
99
// CHECK: return %[[extract]] : f16
1010
%w = arith.truncf %v : f32 to f16
1111
return %w : f16
@@ -14,8 +14,8 @@ func.func @scalar_trunc(%v: f32) -> f16{
1414
// CHECK-LABEL: @vector_trunc
1515
// CHECK-SAME: (%[[value:.*]]: vector<2xf32>)
1616
func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf16> {
17-
// CHECK: %[[elem0:.*]] = vector.extractelement %[[value]]
18-
// CHECK: %[[elem1:.*]] = vector.extractelement %[[value]]
17+
// CHECK: %[[elem0:.*]] = vector.extract %[[value]]
18+
// CHECK: %[[elem1:.*]] = vector.extract %[[value]]
1919
// CHECK: %[[ret:.*]] = rocdl.cvt.pkrtz %[[elem0]], %[[elem1]] : vector<2xf16>
2020
// CHECK: return %[[ret]]
2121
%w = arith.truncf %v : vector<2xf32> to vector<2xf16>
@@ -25,23 +25,23 @@ func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf16> {
2525
// CHECK-LABEL: @vector_trunc_long
2626
// CHECK-SAME: (%[[value:.*]]: vector<9xf32>)
2727
func.func @vector_trunc_long(%v: vector<9xf32>) -> vector<9xf16> {
28-
// CHECK: %[[elem0:.*]] = vector.extractelement %[[value]][%c0 : index]
29-
// CHECK: %[[elem1:.*]] = vector.extractelement %[[value]][%c1 : index]
28+
// CHECK: %[[elem0:.*]] = vector.extract %[[value]][0]
29+
// CHECK: %[[elem1:.*]] = vector.extract %[[value]][1]
3030
// CHECK: %[[packed0:.*]] = rocdl.cvt.pkrtz %[[elem0]], %[[elem1]] : vector<2xf16>
3131
// CHECK: %[[out0:.*]] = vector.insert_strided_slice %[[packed0]], {{.*}} {offsets = [0], strides = [1]} : vector<2xf16> into vector<9xf16>
32-
// CHECK: %[[elem2:.*]] = vector.extractelement %[[value]][%c2 : index]
33-
// CHECK: %[[elem3:.*]] = vector.extractelement %[[value]][%c3 : index]
32+
// CHECK: %[[elem2:.*]] = vector.extract %[[value]][2]
33+
// CHECK: %[[elem3:.*]] = vector.extract %[[value]][3]
3434
// CHECK: %[[packed1:.*]] = rocdl.cvt.pkrtz %[[elem2]], %[[elem3]] : vector<2xf16>
3535
// CHECK: %[[out1:.*]] = vector.insert_strided_slice %[[packed1]], %[[out0]] {offsets = [2], strides = [1]} : vector<2xf16> into vector<9xf16>
36-
// CHECK: %[[elem4:.*]] = vector.extractelement %[[value]][%c4 : index]
37-
// CHECK: %[[elem5:.*]] = vector.extractelement %[[value]][%c5 : index]
36+
// CHECK: %[[elem4:.*]] = vector.extract %[[value]][4]
37+
// CHECK: %[[elem5:.*]] = vector.extract %[[value]][5]
3838
// CHECK: %[[packed2:.*]] = rocdl.cvt.pkrtz %[[elem4]], %[[elem5]] : vector<2xf16>
3939
// CHECK: %[[out2:.*]] = vector.insert_strided_slice %[[packed2]], %[[out1]] {offsets = [4], strides = [1]} : vector<2xf16> into vector<9xf16>
40-
// CHECK: %[[elem6:.*]] = vector.extractelement %[[value]]
41-
// CHECK: %[[elem7:.*]] = vector.extractelement %[[value]]
40+
// CHECK: %[[elem6:.*]] = vector.extract %[[value]]
41+
// CHECK: %[[elem7:.*]] = vector.extract %[[value]]
4242
// CHECK: %[[packed3:.*]] = rocdl.cvt.pkrtz %[[elem6]], %[[elem7]] : vector<2xf16>
4343
// CHECK: %[[out3:.*]] = vector.insert_strided_slice %[[packed3]], %[[out2]] {offsets = [6], strides = [1]} : vector<2xf16> into vector<9xf16>
44-
// CHECK: %[[elem8:.*]] = vector.extractelement %[[value]]
44+
// CHECK: %[[elem8:.*]] = vector.extract %[[value]]
4545
// CHECK: %[[packed4:.*]] = rocdl.cvt.pkrtz %[[elem8:.*]] : vector<2xf16>
4646
// CHECK: %[[slice:.*]] = vector.extract_strided_slice %[[packed4]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf16> to vector<1xf16>
4747
// CHECK: %[[out4:.*]] = vector.insert_strided_slice %[[slice]], %[[out3]] {offsets = [8], strides = [1]} : vector<1xf16> into vector<9xf16>

mlir/test/Dialect/Linalg/vectorization-scalable.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ func.func @vectorize_linalg_index(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) ->
164164
// CHECK: %[[DST_DIM0:.*]] = tensor.dim %[[DST]], %[[C0]] : tensor<?xf32>
165165
// CHECK: %[[MASK:.*]] = vector.create_mask %[[DST_DIM0]] : vector<[4]xi1>
166166
// CHECK-DAG: %[[STEP:.+]] = vector.step : vector<[4]xindex>
167-
// CHECK-DAG: %[[STEP_ELEMENT:.+]] = vector.extractelement %[[STEP]][%c0_i32 : i32] : vector<[4]xindex>
167+
// CHECK-DAG: %[[STEP_ELEMENT:.+]] = vector.extract %[[STEP]][0] : index from vector<[4]xindex>
168168

169169
// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]][%[[STEP_ELEMENT]]], %cst {in_bounds = [true]} : tensor<?xf32>, vector<[4]xf32> } : vector<[4]xi1> -> vector<[4]xf32>
170170
// CHECK: %[[OUT:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[DST]]{{\[}}%[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, tensor<?xf32> } : vector<[4]xi1> -> tensor<?xf32>
@@ -207,7 +207,7 @@ func.func @vectorize_dynamic_reduction_scalable_1d(%arg0: tensor<?xf32>,
207207
// CHECK: %[[VEC_RD_0:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[ARG_0]][%[[C0_idx]]], %[[C0_f32]] {in_bounds = [true]} : tensor<?xf32>, vector<[4]xf32> } : vector<[4]xi1> -> vector<[4]xf32>
208208
// CHECK: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
209209
// CHECK: %[[VEC_RD_1:.*]] = vector.transfer_read %[[ARG_1]][], %[[C0_F32]] : tensor<f32>, vector<f32>
210-
// CHECK: %[[ACC_f32:.*]] = vector.extractelement %[[VEC_RD_1]][] : vector<f32>
210+
// CHECK: %[[ACC_f32:.*]] = vector.extract %[[VEC_RD_1]][] : f32 from vector<f32>
211211
// CHECK: %[[REDUCE:.*]] = vector.mask %[[MASK]] { vector.multi_reduction <add>, %[[VEC_RD_0]], %[[ACC_f32]] [0] : vector<[4]xf32> to f32 } : vector<[4]xi1> -> f32
212212
// CHECK: %[[VEC_f32:.*]] = vector.broadcast %[[REDUCE]] : f32 to vector<f32>
213213
// CHECK: %{{.*}} = vector.transfer_write %[[VEC_f32]], %[[ARG_1]][] : vector<f32>, tensor<f32>

mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ module attributes {transform.with_named_sequence} {
414414
func.func @test_vectorize_copy_scalar(%A : memref<f32>, %B : memref<f32>) {
415415
// CHECK-SAME: (%[[A:.*]]: memref<f32>, %[[B:.*]]: memref<f32>)
416416
// CHECK: %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref<f32>, vector<f32>
417-
// CHECK: %[[val:.*]] = vector.extractelement %[[V]][] : vector<f32>
417+
// CHECK: %[[val:.*]] = vector.extract %[[V]][] : f32 from vector<f32>
418418
// CHECK: %[[VV:.*]] = vector.broadcast %[[val]] : f32 to vector<f32>
419419
// CHECK: vector.transfer_write %[[VV]], %[[B]][] : vector<f32>, memref<f32>
420420
memref.copy %A, %B : memref<f32> to memref<f32>
@@ -1440,7 +1440,6 @@ module attributes {transform.with_named_sequence} {
14401440
// CHECK-LABEL: func @reduce_1d(
14411441
// CHECK-SAME: %[[A:.*]]: tensor<32xf32>
14421442
func.func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
1443-
// CHECK-DAG: %[[vF0:.*]] = arith.constant dense<0.000000e+00> : vector<f32>
14441443
// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f32
14451444
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
14461445
%f0 = arith.constant 0.000000e+00 : f32
@@ -1451,8 +1450,7 @@ func.func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
14511450
%1 = linalg.fill ins(%f0 : f32) outs(%0 : tensor<f32>) -> tensor<f32>
14521451
// CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]]
14531452
// CHECK-SAME: : tensor<32xf32>, vector<32xf32>
1454-
// CHECK: %[[f0:.*]] = vector.extractelement %[[vF0]][] : vector<f32>
1455-
// CHECK: %[[red:.*]] = vector.multi_reduction <add>, %[[r]], %[[f0]] [0]
1453+
// CHECK: %[[red:.*]] = vector.multi_reduction <add>, %[[r]], %[[F0]] [0]
14561454
// CHECK-SAME: : vector<32xf32> to f32
14571455
// CHECK: %[[red_v1:.*]] = vector.broadcast %[[red]] : f32 to vector<f32>
14581456
// CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[init]][]
@@ -1779,9 +1777,9 @@ module attributes {transform.with_named_sequence} {
17791777

17801778
// CHECK-LABEL: func @zero_dim_tensor
17811779
// CHECK: vector.transfer_read {{.*}} : tensor<f32>, vector<f32>
1782-
// CHECK: vector.extractelement
1780+
// CHECK: vector.extract
17831781
// CHECK: vector.transfer_read {{.*}} : tensor<f32>, vector<f32>
1784-
// CHECK: vector.extractelement
1782+
// CHECK: vector.extract
17851783
// CHECK: arith.addf {{.*}} : f32
17861784
// CHECK: vector.broadcast %{{.*}} : f32 to vector<f32>
17871785
// CHECK: vector.transfer_write {{.*}} : vector<f32>, tensor<f32>

mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,10 @@ func.func @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguou
3737
// CHECK: %[[STEP:.*]] = vector.step : vector<4xindex>
3838
// CHECK: %[[IDX_BC:.*]] = vector.broadcast %[[IDX_IN]] : index to vector<4xindex>
3939
// CHECK: %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<4xindex>
40-
// CHECK: %[[C0:.*]] = arith.constant 0 : i32
4140
// CHECK: %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<4xindex> to vector<4xindex>
4241

4342
/// Extract the starting point from the index vector
44-
// CHECK: %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<4xindex>
43+
// CHECK: %[[IDX_START:.*]] = vector.extract %[[SC]][0] : index from vector<4xindex>
4544

4645
// Final read and write
4746
// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
@@ -98,11 +97,10 @@ func.func @masked_static_vectorize_nd_tensor_extract_with_affine_apply_contiguou
9897
// CHECK: %[[STEP:.*]] = vector.step : vector<[4]xindex>
9998
// CHECK: %[[IDX_BC:.*]] = vector.broadcast %[[IDX_IN]] : index to vector<[4]xindex>
10099
// CHECK: %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<[4]xindex>
101-
// CHECK: %[[C0:.*]] = arith.constant 0 : i32
102100
// CHECK: %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<[4]xindex> to vector<[4]xindex>
103101

104102
/// Extract the starting point from the index vector
105-
// CHECK: %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<[4]xindex>
103+
// CHECK: %[[IDX_START:.*]] = vector.extract %[[SC]][0] : index from vector<[4]xindex>
106104

107105
// Final read and write
108106
// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>
@@ -159,11 +157,10 @@ func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguo
159157
// CHECK: %[[STEP:.*]] = vector.step : vector<4xindex>
160158
// CHECK: %[[IDX_BC:.*]] = vector.broadcast %[[IDX]] : index to vector<4xindex>
161159
// CHECK: %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<4xindex>
162-
// CHECK: %[[C0:.*]] = arith.constant 0 : i32
163160
// CHECK: %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<4xindex> to vector<4xindex>
164161

165162
/// Extract the starting point from the index vector
166-
// CHECK: %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<4xindex>
163+
// CHECK: %[[IDX_START:.*]] = vector.extract %[[SC]][0] : index from vector<4xindex>
167164

168165
// Final read and write
169166
// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
@@ -218,11 +215,10 @@ func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguo
218215
// CHECK: %[[STEP:.*]] = vector.step : vector<[4]xindex>
219216
// CHECK: %[[IDX_BC:.*]] = vector.broadcast %[[IDX]] : index to vector<[4]xindex>
220217
// CHECK: %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<[4]xindex>
221-
// CHECK: %[[C0:.*]] = arith.constant 0 : i32
222218
// CHECK: %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<[4]xindex> to vector<[4]xindex>
223219

224220
/// Extract the starting point from the index vector
225-
// CHECK: %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<[4]xindex>
221+
// CHECK: %[[IDX_START:.*]] = vector.extract %[[SC]][0] : index from vector<[4]xindex>
226222

227223
// Final read and write
228224
// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>

0 commit comments

Comments
 (0)