Commit 51348ae

[mlir][Vector] Use a simpler lowering when emulating narrow type for vector.maskedload
arith.select should be used instead of a series of manual mask-manipulating ops (arith.andi/arith.ori/arith.extsi).
Parent: 86afda0
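
To make the change concrete, here is a minimal sketch of the new lowering, adapted from the updated comment in the patch (buffer and value names such as %0, %alloc, %c0, and %linear_index are illustrative): the narrow-type masked load is emulated with a masked load of whole i8 bytes, a bitcast back to i4, and a single arith.select on the original mask, instead of the earlier extsi/andi/xori/ori mask arithmetic.

    // Original op on i4 elements (3 active lanes out of 6):
    %mask = vector.constant_mask [3] : vector<6xi1>
    %1 = vector.maskedload %0[%c0, %c0], %mask, %pass_thru
        : memref<3x6xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>

    // Emulated form: load whole bytes, bitcast, and let arith.select
    // restore the original pass-through values in the lanes where the
    // mask is off.
    %new_mask = vector.constant_mask [2] : vector<3xi1>
    %new_pass_thru = vector.bitcast %pass_thru : vector<6xi4> to vector<3xi8>
    %load = vector.maskedload %alloc[%linear_index], %new_mask, %new_pass_thru
        : memref<9xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
    %bitcast = vector.bitcast %load : vector<3xi8> to vector<6xi4>
    %result = arith.select %mask, %bitcast, %pass_thru : vector<6xi1>, vector<6xi4>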

2 files changed: +38 -148 lines


mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 15 additions & 55 deletions
@@ -135,35 +135,23 @@ struct ConvertVectorMaskedLoad final
 //
 // %mask = vector.constant_mask [3] : vector<6xi1>
 // %1 = vector.maskedload %0[%c0, %c0], %mask, %pass_thru :
-// memref<3x6xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
+// memref<3x6xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
 //
 // can be replaced with
 //
 // %new_mask = vector.constant_mask [2] : vector<3xi1>
-// %new_pass_thru = vector.bitcast %pass_thru : vector<6xi4> to
-// vector<3xi8> %1 = vector.maskedload %0[%linear_index], %new_mask,
-// %new_pass_thru : memref<9xi8>, vector<3xi1>, vector<3xi8> into
-// vector<3xi8>
+// %new_pass_thru = vector.bitcast %pass_thru :
+// vector<6xi4> to vector<3xi8>
+// %1 = vector.maskedload %0[%linear_index], %new_mask, %new_pass_thru :
+// memref<9xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
+// %2 = vector.bitcast %1 : vector<3xi8> to vector<6xi4>
 //
 // Since we are effectively loading 16 bits (2xi8) from the memref with the
 // new mask, while originally we only wanted to effectively load 12 bits
 // (3xi4) from the memref, we need to set the second half of the last i8
-// that was effectively loaded (i.e. the second i8) to 0.
+// that was effectively loaded (i.e. the second i8) to %pass_thru.
 //
-// %unset_mask = arith.extsi %mask : vector<6xi1> to vector<6xi4>
-// %2 = vector.bitcast %unset_mask : vector<6xi4> to vector<3xi8>
-// %3 = arith.andi %1, %2 : vector<3xi8>
-//
-// Then if the second half of the second i8 from %pass_thru is not all 0s,
-// we need to write their values back to the result.
-//
-// %cst_1 = arith.constant dense<-1> : vector<6xi4>
-// %set_mask = arith.xori %unset_mask, %cst_1 : vector<6xi4>
-// %4 = vector.bitcast %set_mask : vector<6xi4> to vector<3xi8>
-// %5 = arith.andi %new_pass_thru, %4 : vector<3xi8>
-//
-// %6 = arith.ori %3, %5 : vector<3xi8>
-// %7 = vector.bitcast %6 : vector<3xi8> to vector<6xi4>
+// %3 = arith.select %mask, %2, %pass_thru : vector<6xi1>, vector<6xi4>
 //
 // Given these input values:
 // %mask = [1, 1, 1, 0, 0, 0]
@@ -177,17 +165,8 @@ struct ConvertVectorMaskedLoad final
 // %new_mask = [1, 1, 0]
 // %new_pass_thru = [0x78, 0x9A, 0xBC]
 // %1 = [0x12, 0x34, 0xBC]
-//
-// %unset_mask = [0xF, 0xF, 0xF, 0, 0, 0]
-// %2 = [0xFF, 0xF0, 0]
-// %3 = [0x12, 0x30, 0]
-//
-// %set_mask = [0, 0, 0, 0xF, 0xF, 0xF]
-// %4 = [0, 0x0F, 0xFF]
-// %5 = [0, 0x0A, 0xBC]
-//
-// %6 = [0x12, 0x3A, 0xBC]
-// %7 = [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
+// %2 = [0x1, 0x2, 0x3, 0x4, 0xB, 0xC]
+// %3 = [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
 //
 // TODO: Currently, only the even number of elements loading is supported.
 // To deal with the odd number of elements, one has to extract the
@@ -280,32 +259,13 @@ struct ConvertVectorMaskedLoad final
     newMask->getResult(0), newPassThru);
 
 // Setting the part that originally was not effectively loaded from memory
-// to 0.
-auto andMask = rewriter.create<arith::ExtSIOp>(loc, origType, op.getMask());
-auto bitCastedAndMask =
-    rewriter.create<vector::BitCastOp>(loc, newType, andMask);
-auto loadedFromMem =
-    rewriter.create<arith::AndIOp>(loc, newLoad, bitCastedAndMask);
-
-// Copying from pass through.
-auto allOne = rewriter.create<arith::ConstantOp>(
-    loc, origType,
-    DenseIntElementsAttr::get(origType, {APInt::getAllOnes(srcBits)}));
-auto passThruMask = rewriter.create<arith::XOrIOp>(loc, allOne.getResult(),
-    andMask.getResult());
-auto bitCastedPassThruMask =
-    rewriter.create<vector::BitCastOp>(loc, newType, passThruMask);
-auto copiedFromPassThru =
-    rewriter.create<arith::AndIOp>(loc, newPassThru, bitCastedPassThruMask);
-
-// Or-ing the first part loaded from memory and the second one copied from
-// pass through to form the result.
-auto result =
-    rewriter.create<arith::OrIOp>(loc, loadedFromMem, copiedFromPassThru);
+// to pass through.
 auto bitCast =
-    rewriter.create<vector::BitCastOp>(loc, op.getType(), result);
+    rewriter.create<vector::BitCastOp>(loc, op.getType(), newLoad);
+auto select = rewriter.create<arith::SelectOp>(loc, op.getMask(), bitCast,
+    op.getPassThru());
+rewriter.replaceOp(op, select->getResult(0));
 
-rewriter.replaceOp(op, bitCast->getResult(0));
 return success();
 }
 };
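
The worked values in the comment above reduce to the element-wise semantics of arith.select: wherever the original i4-granularity mask is set, the lane comes from the bitcast load; elsewhere it comes from the original pass-through vector. A small self-contained illustration (the constants are made up for this sketch and kept within the signed i4 range; they are not the values from the patch):

    func.func @select_semantics_example() -> vector<6xi4> {
      %mask = arith.constant dense<[true, true, true, false, false, false]> : vector<6xi1>
      // Stand-ins for the bitcast load result and the original pass-through.
      %loaded = arith.constant dense<[1, 2, 3, 4, 5, 6]> : vector<6xi4>
      %pass_thru = arith.constant dense<[7, 0, -1, -2, -3, -4]> : vector<6xi4>
      // Result is [1, 2, 3, -2, -3, -4]: loaded lanes where the mask is set,
      // pass-through lanes elsewhere.
      %r = arith.select %mask, %loaded, %pass_thru : vector<6xi1>, vector<6xi4>
      return %r : vector<6xi4>
    }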

mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir

Lines changed: 23 additions & 93 deletions
@@ -141,16 +141,9 @@ func.func @vector_maskedload_i8(%arg1: index, %arg2: index, %arg3: index, %passt
 // CHECK32: %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[ARG3]] : vector<4xi8> to vector<1xi32>
 // CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LD_IDX]]], %[[NEW_MASK]], %[[NEW_PASSTHRU]] :
 // CHECK32-SAME: memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
-// CHECK32: %[[EXT:.+]] = arith.extsi %[[ORIG_MASK]] : vector<4xi1> to vector<4xi8>
-// CHECK32: %[[AND_MASK:.+]] = vector.bitcast %[[EXT]] : vector<4xi8> to vector<1xi32>
-// CHECK32: %[[FIRST_PART:.+]] = arith.andi %[[LOAD]], %[[AND_MASK]] : vector<1xi32>
-// CHECK32: %[[ONES:.+]] = arith.constant dense<-1> : vector<4xi8>
-// CHECK32: %[[XOR:.+]] = arith.xori %[[ONES]], %[[EXT]] : vector<4xi8>
-// CHECK32: %[[PASSTHRU_MASK:.+]] = vector.bitcast %[[XOR]] : vector<4xi8> to vector<1xi32>
-// CHECK32: %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<1xi32>
-// CHECK32: %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<1xi32>
-// CHECK32: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<4xi8>
-// CHECK32: return %[[VEC_I4]]
+// CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<4xi8>
+// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG3]] : vector<4xi1>, vector<4xi8>
+// CHECK32: return %[[SELECT]]
 
 // -----
 
@@ -176,15 +169,8 @@ func.func @vector_maskedload_i4(%arg1: index, %arg2: index, %arg3: index, %passt
 // CHECK: %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[ARG3]] : vector<8xi4> to vector<4xi8>
 // CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LD_IDX]]], %[[NEW_MASK]], %[[NEW_PASSTHRU]] :
 // CHECK-SAME: memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
-// CHECK: %[[EXT:.+]] = arith.extsi %[[ORIG_MASK]] : vector<8xi1> to vector<8xi4>
-// CHECK: %[[AND_MASK:.+]] = vector.bitcast %[[EXT]] : vector<8xi4> to vector<4xi8>
-// CHECK: %[[FIRST_PART:.+]] = arith.andi %[[LOAD]], %[[AND_MASK]] : vector<4xi8>
-// CHECK: %[[ONES:.+]] = arith.constant dense<-1> : vector<8xi4>
-// CHECK: %[[XOR:.+]] = arith.xori %[[ONES]], %[[EXT]] : vector<8xi4>
-// CHECK: %[[PASSTHRU_MASK:.+]] = vector.bitcast %[[XOR]] : vector<8xi4> to vector<4xi8>
-// CHECK: %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<4xi8>
-// CHECK: %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<4xi8>
-// CHECK: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xi4>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
+// CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG3]] : vector<8xi1>, vector<8xi4>
 
 // CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
 // CHECK32-DAG: #[[MASK_IDX_MAP:.+]] = affine_map<()[s0] -> ((s0 + 7) floordiv 8)>
@@ -199,15 +185,8 @@ func.func @vector_maskedload_i4(%arg1: index, %arg2: index, %arg3: index, %passt
 // CHECK32: %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[ARG3]] : vector<8xi4> to vector<1xi32>
 // CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LD_IDX]]], %[[NEW_MASK]], %[[NEW_PASSTHRU]] :
 // CHECK32-SAME: memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
-// CHECK32: %[[EXT:.+]] = arith.extsi %[[ORIG_MASK]] : vector<8xi1> to vector<8xi4>
-// CHECK32: %[[AND_MASK:.+]] = vector.bitcast %[[EXT]] : vector<8xi4> to vector<1xi32>
-// CHECK32: %[[FIRST_PART:.+]] = arith.andi %[[LOAD]], %[[AND_MASK]] : vector<1xi32>
-// CHECK32: %[[ONES:.+]] = arith.constant dense<-1> : vector<8xi4>
-// CHECK32: %[[XOR:.+]] = arith.xori %[[ONES]], %[[EXT]] : vector<8xi4>
-// CHECK32: %[[PASSTHRU_MASK:.+]] = vector.bitcast %[[XOR]] : vector<8xi4> to vector<1xi32>
-// CHECK32: %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<1xi32>
-// CHECK32: %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<1xi32>
-// CHECK32: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xi4>
+// CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4>
+// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG3]] : vector<8xi1>, vector<8xi4>
 
 // -----
 
@@ -239,16 +218,9 @@ func.func @vector_cst_maskedload_i8(%arg1: index, %arg2: index, %passthru: vecto
 // CHECK32: %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[ARG3]] : vector<4xi8> to vector<1xi32>
 // CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LD_IDX]]], %[[NEW_MASK]], %[[NEW_PASSTHRU]] :
 // CHECK32-SAME: memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
-// CHECK32: %[[EXT:.+]] = arith.extsi %[[ORIG_MASK]] : vector<4xi1> to vector<4xi8>
-// CHECK32: %[[AND_MASK:.+]] = vector.bitcast %[[EXT]] : vector<4xi8> to vector<1xi32>
-// CHECK32: %[[FIRST_PART:.+]] = arith.andi %[[LOAD]], %[[AND_MASK]] : vector<1xi32>
-// CHECK32: %[[ONES:.+]] = arith.constant dense<-1> : vector<4xi8>
-// CHECK32: %[[XOR:.+]] = arith.xori %[[ONES]], %[[EXT]] : vector<4xi8>
-// CHECK32: %[[PASSTHRU_MASK:.+]] = vector.bitcast %[[XOR]] : vector<4xi8> to vector<1xi32>
-// CHECK32: %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<1xi32>
-// CHECK32: %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<1xi32>
-// CHECK32: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<4xi8>
-// CHECK32: return %[[VEC_I4]]
+// CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<4xi8>
+// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG3]] : vector<4xi1>, vector<4xi8>
+// CHECK32: return %[[SELECT]]
 
 // -----
 
@@ -272,36 +244,22 @@ func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vecto
 // CHECK: %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[ARG2]] : vector<8xi4> to vector<4xi8>
 // CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LD_IDX]]], %[[NEW_MASK]], %[[NEW_PASSTHRU]] :
 // CHECK-SAME: memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
-// CHECK: %[[EXT:.+]] = arith.extsi %[[ORIG_MASK]] : vector<8xi1> to vector<8xi4>
-// CHECK: %[[AND_MASK:.+]] = vector.bitcast %[[EXT]] : vector<8xi4> to vector<4xi8>
-// CHECK: %[[FIRST_PART:.+]] = arith.andi %[[LOAD]], %[[AND_MASK]] : vector<4xi8>
-// CHECK: %[[ONES:.+]] = arith.constant dense<-1> : vector<8xi4>
-// CHECK: %[[XOR:.+]] = arith.xori %[[ONES]], %[[EXT]] : vector<8xi4>
-// CHECK: %[[PASSTHRU_MASK:.+]] = vector.bitcast %[[XOR]] : vector<8xi4> to vector<4xi8>
-// CHECK: %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<4xi8>
-// CHECK: %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<4xi8>
-// CHECK: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xi4>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
+// CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG2]] : vector<8xi1>, vector<8xi4>
 
 // CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
 // CHECK32: func @vector_cst_maskedload_i4(
 // CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
 // CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: vector<8xi4>)
 // CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
-// CHECK32: %[[ORIG_MASK:.+]] = vector.constant_mask [4] : vector<8xi1>
+// CHECK32: %[[ORIG_MASK:.+]] = vector.constant_mask [4] : vector<8xi1>
 // CHECK32: %[[LD_IDX:.+]] = affine.apply #[[LOAD_IDX_MAP]]()[%[[ARG0]], %[[ARG1]]]
 // CHECK32: %[[NEW_MASK:.+]] = vector.constant_mask [1] : vector<1xi1>
 // CHECK32: %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[ARG2]] : vector<8xi4> to vector<1xi32>
 // CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LD_IDX]]], %[[NEW_MASK]], %[[NEW_PASSTHRU]] :
 // CHECK32-SAME: memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
-// CHECK32: %[[EXT:.+]] = arith.extsi %[[ORIG_MASK]] : vector<8xi1> to vector<8xi4>
-// CHECK32: %[[AND_MASK:.+]] = vector.bitcast %[[EXT]] : vector<8xi4> to vector<1xi32>
-// CHECK32: %[[FIRST_PART:.+]] = arith.andi %[[LOAD]], %[[AND_MASK]] : vector<1xi32>
-// CHECK32: %[[ONES:.+]] = arith.constant dense<-1> : vector<8xi4>
-// CHECK32: %[[XOR:.+]] = arith.xori %[[ONES]], %[[EXT]] : vector<8xi4>
-// CHECK32: %[[PASSTHRU_MASK:.+]] = vector.bitcast %[[XOR]] : vector<8xi4> to vector<1xi32>
-// CHECK32: %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<1xi32>
-// CHECK32: %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<1xi32>
-// CHECK32: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xi4>
+// CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4>
+// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG2]] : vector<8xi1>, vector<8xi4>
 
 // -----
 
@@ -331,15 +289,8 @@ func.func @vector_extract_maskedload_i4(%arg1: index) -> vector<8x8x16xi4> {
 // CHECK: %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[PASSTHRU]] : vector<16xi4> to vector<8xi8>
 // CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%c0], %[[NEW_EXT2]], %[[NEW_PASSTHRU]] :
 // CHECK-SAME: memref<512xi8>, vector<8xi1>, vector<8xi8> into vector<8xi8>
-// CHECK: %[[EXT:.+]] = arith.extsi %[[ORIG_EXT2]] : vector<16xi1> to vector<16xi4>
-// CHECK: %[[AND_MASK:.+]] = vector.bitcast %[[EXT]] : vector<16xi4> to vector<8xi8>
-// CHECK: %[[FIRST_PART:.+]] = arith.andi %[[LOAD]], %[[AND_MASK]] : vector<8xi8>
-// CHECK: %[[ONES:.+]] = arith.constant dense<-1> : vector<16xi4>
-// CHECK: %[[XOR:.+]] = arith.xori %[[ONES]], %[[EXT]] : vector<16xi4>
-// CHECK: %[[PASSTHRU_MASK:.+]] = vector.bitcast %[[XOR]] : vector<16xi4> to vector<8xi8>
-// CHECK: %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<8xi8>
-// CHECK: %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<8xi8>
-// CHECK: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<8xi8> to vector<16xi4>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<8xi8> to vector<16xi4>
+// CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_EXT2]], %[[BITCAST]], %[[PASSTHRU]] : vector<16xi1>, vector<16xi4>
 
 // CHECK32: func @vector_extract_maskedload_i4(
 // CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<128xi32>
@@ -353,15 +304,8 @@ func.func @vector_extract_maskedload_i4(%arg1: index) -> vector<8x8x16xi4> {
 // CHECK32: %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[PASSTHRU]] : vector<16xi4> to vector<2xi32>
 // CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%c0], %[[NEW_EXT2]], %[[NEW_PASSTHRU]] :
 // CHECK32-SAME: memref<128xi32>, vector<2xi1>, vector<2xi32> into vector<2xi32>
-// CHECK32: %[[EXT:.+]] = arith.extsi %[[ORIG_EXT2]] : vector<16xi1> to vector<16xi4>
-// CHECK32: %[[AND_MASK:.+]] = vector.bitcast %[[EXT]] : vector<16xi4> to vector<2xi32>
-// CHECK32: %[[FIRST_PART:.+]] = arith.andi %[[LOAD]], %[[AND_MASK]] : vector<2xi32>
-// CHECK32: %[[ONES:.+]] = arith.constant dense<-1> : vector<16xi4>
-// CHECK32: %[[XOR:.+]] = arith.xori %[[ONES]], %[[EXT]] : vector<16xi4>
-// CHECK32: %[[PASSTHRU_MASK:.+]] = vector.bitcast %[[XOR]] : vector<16xi4> to vector<2xi32>
-// CHECK32: %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<2xi32>
-// CHECK32: %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<2xi32>
-// CHECK32: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<2xi32> to vector<16xi4>
+// CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<2xi32> to vector<16xi4>
+// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_EXT2]], %[[BITCAST]], %[[PASSTHRU]] : vector<16xi1>, vector<16xi4>
 
 // -----
 
@@ -389,15 +333,8 @@ func.func @vector_extract_cst_maskedload_i4() -> vector<8x8x16xi4> {
 // CHECK: %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[PASSTHRU]] : vector<16xi4> to vector<8xi8>
 // CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%c0], %[[NEW_EXT2]], %[[NEW_PASSTHRU]] :
 // CHECK-SAME: memref<512xi8>, vector<8xi1>, vector<8xi8> into vector<8xi8>
-// CHECK: %[[EXT:.+]] = arith.extsi %[[ORIG_EXT2]] : vector<16xi1> to vector<16xi4>
-// CHECK: %[[AND_MASK:.+]] = vector.bitcast %[[EXT]] : vector<16xi4> to vector<8xi8>
-// CHECK: %[[FIRST_PART:.+]] = arith.andi %[[LOAD]], %[[AND_MASK]] : vector<8xi8>
-// CHECK: %[[ONES:.+]] = arith.constant dense<-1> : vector<16xi4>
-// CHECK: %[[XOR:.+]] = arith.xori %[[ONES]], %[[EXT]] : vector<16xi4>
-// CHECK: %[[PASSTHRU_MASK:.+]] = vector.bitcast %[[XOR]] : vector<16xi4> to vector<8xi8>
-// CHECK: %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<8xi8>
-// CHECK: %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<8xi8>
-// CHECK: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<8xi8> to vector<16xi4>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<8xi8> to vector<16xi4>
+// CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_EXT2]], %[[BITCAST]], %[[PASSTHRU]] : vector<16xi1>, vector<16xi4>
 
 // CHECK32: func @vector_extract_cst_maskedload_i4(
 // CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<128xi32>
@@ -411,12 +348,5 @@ func.func @vector_extract_cst_maskedload_i4() -> vector<8x8x16xi4> {
 // CHECK32: %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[PASSTHRU]] : vector<16xi4> to vector<2xi32>
 // CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%c0], %[[NEW_EXT2]], %[[NEW_PASSTHRU]] :
 // CHECK32-SAME: memref<128xi32>, vector<2xi1>, vector<2xi32> into vector<2xi32>
-// CHECK32: %[[EXT:.+]] = arith.extsi %[[ORIG_EXT2]] : vector<16xi1> to vector<16xi4>
-// CHECK32: %[[AND_MASK:.+]] = vector.bitcast %[[EXT]] : vector<16xi4> to vector<2xi32>
-// CHECK32: %[[FIRST_PART:.+]] = arith.andi %[[LOAD]], %[[AND_MASK]] : vector<2xi32>
-// CHECK32: %[[ONES:.+]] = arith.constant dense<-1> : vector<16xi4>
-// CHECK32: %[[XOR:.+]] = arith.xori %[[ONES]], %[[EXT]] : vector<16xi4>
-// CHECK32: %[[PASSTHRU_MASK:.+]] = vector.bitcast %[[XOR]] : vector<16xi4> to vector<2xi32>
-// CHECK32: %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<2xi32>
-// CHECK32: %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<2xi32>
-// CHECK32: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<2xi32> to vector<16xi4>
+// CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<2xi32> to vector<16xi4>
+// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_EXT2]], %[[BITCAST]], %[[PASSTHRU]] : vector<16xi1>, vector<16xi4>
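
The hunks above show only the updated CHECK lines. For orientation, the i4 masked-load cases being checked have roughly the following shape (a hypothetical reconstruction from the truncated @@ headers; the committed test's exact operand names, mask construction, and memref shape may differ):

    func.func @maskedload_i4_sketch(%row: index, %col: index, %num_elems: index,
                                    %pass_thru: vector<8xi4>) -> vector<8xi4> {
      %src = memref.alloc() : memref<3x8xi4>
      %mask = vector.create_mask %num_elems : vector<8xi1>
      // The emulation pass rewrites this into an i8 (or i32) maskedload,
      // a vector.bitcast, and an arith.select on %mask.
      %0 = vector.maskedload %src[%row, %col], %mask, %pass_thru
          : memref<3x8xi4>, vector<8xi1>, vector<8xi4> into vector<8xi4>
      return %0 : vector<8xi4>
    }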
