Skip to content

Commit 86afda0

Browse files
committed
[mlir][Vector] Handle narrow type emulation of vector.maskedload when mask is an extraction
1 parent 2acafc1 commit 86afda0

File tree

2 files changed

+147
-7
lines changed

2 files changed

+147
-7
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -213,16 +213,28 @@ struct ConvertVectorMaskedLoad final
213213
auto numElements = (origElements + scale - 1) / scale;
214214
auto newType = VectorType::get(numElements, newElementType);
215215

216-
auto createMaskOp = op.getMask().getDefiningOp<vector::CreateMaskOp>();
217-
auto constantMaskOp = op.getMask().getDefiningOp<vector::ConstantMaskOp>();
218-
// TODO: Handle extracted mask.
216+
auto maskOp = op.getMask().getDefiningOp();
217+
SmallVector<vector::ExtractOp, 2> extractOps;
218+
// Finding the mask creation operation.
219+
while (maskOp &&
220+
!isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
221+
if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
222+
maskOp = extractOp.getVector().getDefiningOp();
223+
extractOps.push_back(extractOp);
224+
}
225+
}
226+
auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
227+
auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
219228
if (!createMaskOp && !constantMaskOp)
220229
return failure();
221230

222231
// Computing the "compressed" mask. All the emulation logic (i.e. computing
223232
// new mask index) only happens on the last dimension of the vectors.
224233
Operation *newMask = nullptr;
225-
auto newMaskType = VectorType::get(numElements, rewriter.getI1Type());
234+
auto shape = llvm::to_vector(
235+
maskOp->getResultTypes()[0].cast<VectorType>().getShape().drop_back());
236+
shape.push_back(numElements);
237+
auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
226238
if (createMaskOp) {
227239
auto maskOperands = createMaskOp.getOperands();
228240
auto numMaskOperands = maskOperands.size();
@@ -234,18 +246,28 @@ struct ConvertVectorMaskedLoad final
234246
getAsOpFoldResult(maskOperands[numMaskOperands - 1]);
235247
OpFoldResult maskIndex =
236248
affine::makeComposedFoldedAffineApply(rewriter, loc, s0, origIndex);
237-
newMask = rewriter.create<vector::CreateMaskOp>(
238-
loc, newMaskType,
249+
auto newMaskOperands = llvm::to_vector(maskOperands.drop_back());
250+
newMaskOperands.push_back(
239251
getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
252+
newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
253+
newMaskOperands);
240254
} else if (constantMaskOp) {
241255
auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
242256
auto numMaskOperands = maskDimSizes.size();
243257
auto origIndex =
244258
cast<IntegerAttr>(maskDimSizes[numMaskOperands - 1]).getInt();
245259
auto maskIndex =
246260
rewriter.getI64IntegerAttr((origIndex + scale - 1) / scale);
261+
auto newMaskDimSizes = llvm::to_vector(maskDimSizes.drop_back());
262+
newMaskDimSizes.push_back(maskIndex);
247263
newMask = rewriter.create<vector::ConstantMaskOp>(
248-
loc, newMaskType, ArrayAttr::get(op.getContext(), maskIndex));
264+
loc, newMaskType, rewriter.getArrayAttr(newMaskDimSizes));
265+
}
266+
267+
while (!extractOps.empty()) {
268+
newMask = rewriter.create<vector::ExtractOp>(
269+
loc, newMask->getResults()[0], extractOps.back().getMixedPosition());
270+
extractOps.pop_back();
249271
}
250272

251273
auto newPassThru =

mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,3 +302,121 @@ func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vecto
302302
// CHECK32: %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<1xi32>
303303
// CHECK32: %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<1xi32>
304304
// CHECK32: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xi4>
305+
306+
// -----
307+
308+
func.func @vector_extract_maskedload_i4(%arg1: index) -> vector<8x8x16xi4> {
309+
%0 = memref.alloc() : memref<8x8x16xi4>
310+
%c0 = arith.constant 0 : index
311+
%c16 = arith.constant 16 : index
312+
%c8 = arith.constant 8 : index
313+
%cst_1 = arith.constant dense<0> : vector<8x8x16xi4>
314+
%cst_2 = arith.constant dense<0> : vector<16xi4>
315+
%27 = vector.create_mask %c8, %arg1, %c16 : vector<8x8x16xi1>
316+
%48 = vector.extract %27[0] : vector<8x16xi1> from vector<8x8x16xi1>
317+
%49 = vector.extract %48[0] : vector<16xi1> from vector<8x16xi1>
318+
%50 = vector.maskedload %0[%c0, %c0, %c0], %49, %cst_2 : memref<8x8x16xi4>, vector<16xi1>, vector<16xi4> into vector<16xi4>
319+
%63 = vector.insert %50, %cst_1 [0, 0] : vector<16xi4> into vector<8x8x16xi4>
320+
return %63 : vector<8x8x16xi4>
321+
}
322+
// CHECK: func @vector_extract_maskedload_i4(
323+
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<512xi8>
324+
// CHECK: %[[PASSTHRU:.+]] = arith.constant dense<0> : vector<16xi4>
325+
// CHECK: %[[ORIG_MASK:.+]] = vector.create_mask {{.*}} vector<8x8x16xi1>
326+
// CHECK: %[[ORIG_EXT1:.+]] = vector.extract %[[ORIG_MASK]][0] : vector<8x16xi1>
327+
// CHECK: %[[ORIG_EXT2:.+]] = vector.extract %[[ORIG_EXT1]][0] : vector<16xi1>
328+
// CHECK: %[[NEW_MASK:.+]] = vector.create_mask {{.*}} vector<8x8x8xi1>
329+
// CHECK: %[[NEW_EXT1:.+]] = vector.extract %[[NEW_MASK]][0] : vector<8x8xi1>
330+
// CHECK: %[[NEW_EXT2:.+]] = vector.extract %[[NEW_EXT1]][0] : vector<8xi1>
331+
// CHECK: %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[PASSTHRU]] : vector<16xi4> to vector<8xi8>
332+
// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%c0], %[[NEW_EXT2]], %[[NEW_PASSTHRU]] :
333+
// CHECK-SAME: memref<512xi8>, vector<8xi1>, vector<8xi8> into vector<8xi8>
334+
// CHECK: %[[EXT:.+]] = arith.extsi %[[ORIG_EXT2]] : vector<16xi1> to vector<16xi4>
335+
// CHECK: %[[AND_MASK:.+]] = vector.bitcast %[[EXT]] : vector<16xi4> to vector<8xi8>
336+
// CHECK: %[[FIRST_PART:.+]] = arith.andi %[[LOAD]], %[[AND_MASK]] : vector<8xi8>
337+
// CHECK: %[[ONES:.+]] = arith.constant dense<-1> : vector<16xi4>
338+
// CHECK: %[[XOR:.+]] = arith.xori %[[ONES]], %[[EXT]] : vector<16xi4>
339+
// CHECK: %[[PASSTHRU_MASK:.+]] = vector.bitcast %[[XOR]] : vector<16xi4> to vector<8xi8>
340+
// CHECK: %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<8xi8>
341+
// CHECK: %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<8xi8>
342+
// CHECK: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<8xi8> to vector<16xi4>
343+
344+
// CHECK32: func @vector_extract_maskedload_i4(
345+
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<128xi32>
346+
// CHECK32: %[[PASSTHRU:.+]] = arith.constant dense<0> : vector<16xi4>
347+
// CHECK32: %[[ORIG_MASK:.+]] = vector.create_mask {{.*}} vector<8x8x16xi1>
348+
// CHECK32: %[[ORIG_EXT1:.+]] = vector.extract %[[ORIG_MASK]][0] : vector<8x16xi1>
349+
// CHECK32: %[[ORIG_EXT2:.+]] = vector.extract %[[ORIG_EXT1]][0] : vector<16xi1>
350+
// CHECK32: %[[NEW_MASK:.+]] = vector.create_mask {{.*}} vector<8x8x2xi1>
351+
// CHECK32: %[[NEW_EXT1:.+]] = vector.extract %[[NEW_MASK]][0] : vector<8x2xi1>
352+
// CHECK32: %[[NEW_EXT2:.+]] = vector.extract %[[NEW_EXT1]][0] : vector<2xi1>
353+
// CHECK32: %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[PASSTHRU]] : vector<16xi4> to vector<2xi32>
354+
// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%c0], %[[NEW_EXT2]], %[[NEW_PASSTHRU]] :
355+
// CHECK32-SAME: memref<128xi32>, vector<2xi1>, vector<2xi32> into vector<2xi32>
356+
// CHECK32: %[[EXT:.+]] = arith.extsi %[[ORIG_EXT2]] : vector<16xi1> to vector<16xi4>
357+
// CHECK32: %[[AND_MASK:.+]] = vector.bitcast %[[EXT]] : vector<16xi4> to vector<2xi32>
358+
// CHECK32: %[[FIRST_PART:.+]] = arith.andi %[[LOAD]], %[[AND_MASK]] : vector<2xi32>
359+
// CHECK32: %[[ONES:.+]] = arith.constant dense<-1> : vector<16xi4>
360+
// CHECK32: %[[XOR:.+]] = arith.xori %[[ONES]], %[[EXT]] : vector<16xi4>
361+
// CHECK32: %[[PASSTHRU_MASK:.+]] = vector.bitcast %[[XOR]] : vector<16xi4> to vector<2xi32>
362+
// CHECK32: %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<2xi32>
363+
// CHECK32: %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<2xi32>
364+
// CHECK32: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<2xi32> to vector<16xi4>
365+
366+
// -----
367+
368+
func.func @vector_extract_cst_maskedload_i4() -> vector<8x8x16xi4> {
369+
%0 = memref.alloc() : memref<8x8x16xi4>
370+
%c0 = arith.constant 0 : index
371+
%cst_1 = arith.constant dense<0> : vector<8x8x16xi4>
372+
%cst_2 = arith.constant dense<0> : vector<16xi4>
373+
%27 = vector.constant_mask [8, 4, 16] : vector<8x8x16xi1>
374+
%48 = vector.extract %27[0] : vector<8x16xi1> from vector<8x8x16xi1>
375+
%49 = vector.extract %48[0] : vector<16xi1> from vector<8x16xi1>
376+
%50 = vector.maskedload %0[%c0, %c0, %c0], %49, %cst_2 : memref<8x8x16xi4>, vector<16xi1>, vector<16xi4> into vector<16xi4>
377+
%63 = vector.insert %50, %cst_1 [0, 0] : vector<16xi4> into vector<8x8x16xi4>
378+
return %63 : vector<8x8x16xi4>
379+
}
380+
// CHECK: func @vector_extract_cst_maskedload_i4(
381+
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<512xi8>
382+
// CHECK: %[[PASSTHRU:.+]] = arith.constant dense<0> : vector<16xi4>
383+
// CHECK: %[[ORIG_MASK:.+]] = vector.constant_mask {{.*}} vector<8x8x16xi1>
384+
// CHECK: %[[ORIG_EXT1:.+]] = vector.extract %[[ORIG_MASK]][0] : vector<8x16xi1>
385+
// CHECK: %[[ORIG_EXT2:.+]] = vector.extract %[[ORIG_EXT1]][0] : vector<16xi1>
386+
// CHECK: %[[NEW_MASK:.+]] = vector.constant_mask {{.*}} vector<8x8x8xi1>
387+
// CHECK: %[[NEW_EXT1:.+]] = vector.extract %[[NEW_MASK]][0] : vector<8x8xi1>
388+
// CHECK: %[[NEW_EXT2:.+]] = vector.extract %[[NEW_EXT1]][0] : vector<8xi1>
389+
// CHECK: %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[PASSTHRU]] : vector<16xi4> to vector<8xi8>
390+
// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%c0], %[[NEW_EXT2]], %[[NEW_PASSTHRU]] :
391+
// CHECK-SAME: memref<512xi8>, vector<8xi1>, vector<8xi8> into vector<8xi8>
392+
// CHECK: %[[EXT:.+]] = arith.extsi %[[ORIG_EXT2]] : vector<16xi1> to vector<16xi4>
393+
// CHECK: %[[AND_MASK:.+]] = vector.bitcast %[[EXT]] : vector<16xi4> to vector<8xi8>
394+
// CHECK: %[[FIRST_PART:.+]] = arith.andi %[[LOAD]], %[[AND_MASK]] : vector<8xi8>
395+
// CHECK: %[[ONES:.+]] = arith.constant dense<-1> : vector<16xi4>
396+
// CHECK: %[[XOR:.+]] = arith.xori %[[ONES]], %[[EXT]] : vector<16xi4>
397+
// CHECK: %[[PASSTHRU_MASK:.+]] = vector.bitcast %[[XOR]] : vector<16xi4> to vector<8xi8>
398+
// CHECK: %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<8xi8>
399+
// CHECK: %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<8xi8>
400+
// CHECK: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<8xi8> to vector<16xi4>
401+
402+
// CHECK32: func @vector_extract_cst_maskedload_i4(
403+
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<128xi32>
404+
// CHECK32: %[[PASSTHRU:.+]] = arith.constant dense<0> : vector<16xi4>
405+
// CHECK32: %[[ORIG_MASK:.+]] = vector.constant_mask {{.*}} vector<8x8x16xi1>
406+
// CHECK32: %[[ORIG_EXT1:.+]] = vector.extract %[[ORIG_MASK]][0] : vector<8x16xi1>
407+
// CHECK32: %[[ORIG_EXT2:.+]] = vector.extract %[[ORIG_EXT1]][0] : vector<16xi1>
408+
// CHECK32: %[[NEW_MASK:.+]] = vector.constant_mask {{.*}} vector<8x8x2xi1>
409+
// CHECK32: %[[NEW_EXT1:.+]] = vector.extract %[[NEW_MASK]][0] : vector<8x2xi1>
410+
// CHECK32: %[[NEW_EXT2:.+]] = vector.extract %[[NEW_EXT1]][0] : vector<2xi1>
411+
// CHECK32: %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[PASSTHRU]] : vector<16xi4> to vector<2xi32>
412+
// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%c0], %[[NEW_EXT2]], %[[NEW_PASSTHRU]] :
413+
// CHECK32-SAME: memref<128xi32>, vector<2xi1>, vector<2xi32> into vector<2xi32>
414+
// CHECK32: %[[EXT:.+]] = arith.extsi %[[ORIG_EXT2]] : vector<16xi1> to vector<16xi4>
415+
// CHECK32: %[[AND_MASK:.+]] = vector.bitcast %[[EXT]] : vector<16xi4> to vector<2xi32>
416+
// CHECK32: %[[FIRST_PART:.+]] = arith.andi %[[LOAD]], %[[AND_MASK]] : vector<2xi32>
417+
// CHECK32: %[[ONES:.+]] = arith.constant dense<-1> : vector<16xi4>
418+
// CHECK32: %[[XOR:.+]] = arith.xori %[[ONES]], %[[EXT]] : vector<16xi4>
419+
// CHECK32: %[[PASSTHRU_MASK:.+]] = vector.bitcast %[[XOR]] : vector<16xi4> to vector<2xi32>
420+
// CHECK32: %[[SECOND_PART:.+]] = arith.andi %[[NEW_PASSTHRU]], %[[PASSTHRU_MASK]] : vector<2xi32>
421+
// CHECK32: %[[VEC:.+]] = arith.ori %[[FIRST_PART]], %[[SECOND_PART]] : vector<2xi32>
422+
// CHECK32: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<2xi32> to vector<16xi4>

0 commit comments

Comments (0)