Skip to content

Commit eef52e2

Browse files
committed
fixup! [mlir][vector][nfc] Add tests + update docs for narrow-type emulation
* Fix failing test
* Tweak/fix the comment
* Rename: @vector_cst_maskedload_i8 -> @vector_maskedload_i8_constant_mask (same for other similar tests)
1 parent 8a9abf6 commit eef52e2

File tree

2 files changed

+34
-32
lines changed

2 files changed

+34
-32
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -323,30 +323,34 @@ struct ConvertVectorMaskedStore final
323323
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);
324324

325325
// Load the whole data and use arith.select to handle the corner cases.
326-
// E.g., given these input i4 values:
327326
//
328-
// %res = vector.maskedload %0[%c0, %c0], %mask, %val_to_store :
327+
// As an example, for this masked store:
329328
//
330-
// %mask = [1, 1, 1, 1, 1, 1, 1, 0] (8 * i1)
329+
// vector.maskedstore %0[%c0, %c0], %mask, %val_to_store
330+
//
331+
// and given these input i4 values:
332+
//
333+
// %mask = [1, 1, 1, 1, 1, 0, 0, 0] (8 * i1)
331334
// %0[%c0, %c0] =
332335
// [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8] (8 * i4)
333336
// %val_to_store =
334337
// [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0] (8 * i4)
335338
//
336339
// we'll have the following i4 output:
337340
//
338-
// expected output: [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x8]
341+
// expected output: [0x9, 0xA, 0xB, 0xC, 0xD, 0x6, 0x7, 0x8]
339342
//
340343
// Emulating the above using i8 will give:
341344
//
342-
// %compressed_mask = [1, 1, 1, 1] (4 * i1)
343-
// %maskedload = [0x12, 0x34, 0x56, 0x78] (4 * i8)
344-
// %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8] (8 * i4)
345+
// %compressed_mask = [1, 1, 1, 0] (4 * i1)
346+
// %maskedload = [0x12, 0x34, 0x56, 0x00] (4 * i8)
347+
// %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x0, 0x0] (8 * i4)
345348
// %select_using_shifted_mask =
346-
// [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x8] (8 * i4)
347-
// %packed_data = [0x9A, 0xBC, 0xDE, 0xF8] (4 * i8)
349+
// [0x9, 0xA, 0xB, 0xC, 0xD, 0x6, 0x0, 0x0] (8 * i4)
350+
// %packed_data = [0x9A, 0xBC, 0xD6, 0x00] (4 * i8)
348351
//
349-
// Using the new mask to store %packed_data results in expected output.
352+
// Using the compressed mask to store %packed_data results in the expected
353+
// output.
350354
FailureOr<Operation *> newMask =
351355
getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
352356
if (failed(newMask))

mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -202,15 +202,15 @@ func.func @vector_maskedload_i4(%arg1: index, %arg2: index, %arg3: index, %passt
202202

203203
// -----
204204

205-
func.func @vector_cst_maskedload_i8(%arg1: index, %arg2: index, %passthru: vector<4xi8>) -> vector<4xi8> {
205+
func.func @vector_maskedload_i8_constant_mask(%arg1: index, %arg2: index, %passthru: vector<4xi8>) -> vector<4xi8> {
206206
%0 = memref.alloc() : memref<3x4xi8>
207207
%mask = vector.constant_mask [2] : vector<4xi1>
208208
%1 = vector.maskedload %0[%arg1, %arg2], %mask, %passthru :
209209
memref<3x4xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
210210
return %1 : vector<4xi8>
211211
}
212212
// Expect no conversions, i8 is supported.
213-
// CHECK: func @vector_cst_maskedload_i8(
213+
// CHECK: func @vector_maskedload_i8_constant_mask(
214214
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
215215
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: vector<4xi8>)
216216
// CHECK-NEXT: %[[ALLOC:.+]] = memref.alloc() : memref<3x4xi8>
@@ -220,7 +220,7 @@ func.func @vector_cst_maskedload_i8(%arg1: index, %arg2: index, %passthru: vecto
220220
// CHECK-NEXT: return
221221

222222
// CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 4)>
223-
// CHECK32: func @vector_cst_maskedload_i8(
223+
// CHECK32: func @vector_maskedload_i8_constant_mask(
224224
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
225225
// CHECK32-SAME: %[[ARG3:[a-zA-Z0-9]+]]: vector<4xi8>)
226226
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
@@ -236,7 +236,7 @@ func.func @vector_cst_maskedload_i8(%arg1: index, %arg2: index, %passthru: vecto
236236

237237
// -----
238238

239-
func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vector<8xi4>) -> vector<3x8xi4> {
239+
func.func @vector_maskedload_i4_constant_mask(%arg1: index, %arg2: index, %passthru: vector<8xi4>) -> vector<3x8xi4> {
240240
%0 = memref.alloc() : memref<3x8xi4>
241241
%cst = arith.constant dense<0> : vector<3x8xi4>
242242
%mask = vector.constant_mask [4] : vector<8xi1>
@@ -246,7 +246,7 @@ func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vecto
246246
return %2 : vector<3x8xi4>
247247
}
248248
// CHECK-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
249-
// CHECK: func @vector_cst_maskedload_i4(
249+
// CHECK: func @vector_maskedload_i4_constant_mask(
250250
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
251251
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: vector<8xi4>)
252252
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
@@ -260,7 +260,7 @@ func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vecto
260260
// CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG2]] : vector<8xi1>, vector<8xi4>
261261

262262
// CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
263-
// CHECK32: func @vector_cst_maskedload_i4(
263+
// CHECK32: func @vector_maskedload_i4_constant_mask(
264264
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
265265
// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: vector<8xi4>)
266266
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
@@ -500,7 +500,6 @@ func.func @vector_maskedstore_i4(
500500
%value: vector<8xi4>) {
501501

502502
%0 = memref.alloc() : memref<3x8xi4>
503-
%cst = arith.constant dense<0> : vector<3x8xi4>
504503
%mask = vector.create_mask %num_elements_to_store : vector<8xi1>
505504
vector.maskedstore %0[%idx1, %idx2], %mask, %value :
506505
memref<3x8xi4>, vector<8xi1>, vector<8xi4>
@@ -548,14 +547,14 @@ func.func @vector_maskedstore_i4(
548547

549548
// -----
550549

551-
func.func @vector_cst_maskedstore_i8(%arg0: index, %arg1: index, %value: vector<8xi8>) {
550+
func.func @vector_maskedstore_i8_constant_mask(%arg0: index, %arg1: index, %value: vector<8xi8>) {
552551
%0 = memref.alloc() : memref<3x8xi8>
553552
%mask = vector.constant_mask [4] : vector<8xi1>
554553
vector.maskedstore %0[%arg0, %arg1], %mask, %value : memref<3x8xi8>, vector<8xi1>, vector<8xi8>
555554
return
556555
}
557556
// Expect no conversions, i8 is supported.
558-
// CHECK: func @vector_cst_maskedstore_i8(
557+
// CHECK: func @vector_maskedstore_i8_constant_mask(
559558
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
560559
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
561560
// CHECK-SAME: %[[VAL:[a-zA-Z0-9]+]]
@@ -565,7 +564,7 @@ func.func @vector_cst_maskedstore_i8(%arg0: index, %arg1: index, %value: vector<
565564
// CHECK-NEXT: return
566565

567566
// CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 2 + s1 floordiv 4)>
568-
// CHECK32: func @vector_cst_maskedstore_i8(
567+
// CHECK32: func @vector_maskedstore_i8_constant_mask(
569568
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]
570569
// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9]+]]
571570
// CHECK32-SAME: %[[VAL:[a-zA-Z0-9]+]]
@@ -582,21 +581,20 @@ func.func @vector_cst_maskedstore_i8(%arg0: index, %arg1: index, %value: vector<
582581

583582
// -----
584583

585-
func.func @vector_cst_maskedstore_i4(
584+
func.func @vector_maskedstore_i4_constant_mask(
586585
%idx_1: index,
587586
%idx_2: index,
588587
%val_to_store: vector<8xi4>) {
589588

590589
%0 = memref.alloc() : memref<3x8xi4>
591-
%cst = arith.constant dense<0> : vector<3x8xi4>
592590
%mask = vector.constant_mask [4] : vector<8xi1>
593591
vector.maskedstore %0[%idx_1, %idx_2], %mask, %val_to_store :
594592
memref<3x8xi4>, vector<8xi1>, vector<8xi4>
595593
return
596594
}
597595

598596
// CHECK: #[[$ATTR_12:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
599-
// CHECK-LABEL: func.func @vector_cst_maskedstore_i4(
597+
// CHECK-LABEL: func.func @vector_maskedstore_i4_constant_mask(
600598
// CHECK-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
601599
// CHECK-SAME: %[[IDX_2:[a-zA-Z0-9]+]]: index,
602600
// CHECK-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
@@ -606,13 +604,13 @@ func.func @vector_cst_maskedstore_i4(
606604
// CHECK: %[[NEW_MASK:.+]] = vector.constant_mask [2] : vector<4xi1>
607605
// CHECK: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<4xi8>
608606
// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
609-
// CHECK: %[[VAL_9:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
610-
// CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[VAL_9]] : vector<8xi1>, vector<8xi4>
611-
// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<4xi8>
612-
// CHECK: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[BITCAST]] : memref<12xi8>, vector<4xi1>, vector<4xi8>
607+
// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
608+
// CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
609+
// CHECK: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<4xi8>
610+
// CHECK: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<12xi8>, vector<4xi1>, vector<4xi8>
613611

614612
// CHECK32: #[[$ATTR_20:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
615-
// CHECK32-LABEL: func.func @vector_cst_maskedstore_i4(
613+
// CHECK32-LABEL: func.func @vector_maskedstore_i4_constant_mask(
616614
// CHECK32-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
617615
// CHECK32-SAME: %[[IDX_2:[a-zA-Z0-9]+]]: index,
618616
// CHECK32-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
@@ -622,7 +620,7 @@ func.func @vector_cst_maskedstore_i4(
622620
// CHECK32: %[[NEW_MASK:.+]] = vector.constant_mask [1] : vector<1xi1>
623621
// CHECK32: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<1xi32>
624622
// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
625-
// CHECK32: %[[VAL_9:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4>
626-
// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_2]], %[[VAL_9]] : vector<8xi1>, vector<8xi4>
627-
// CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32>
628-
// CHECK32: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[BITCAST]] : memref<3xi32>, vector<1xi1>, vector<1xi32>
623+
// CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4>
624+
// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
625+
// CHECK32: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32>
626+
// CHECK32: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<3xi32>, vector<1xi1>, vector<1xi32>

0 commit comments

Comments
 (0)