// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=32" --cse --split-input-file %s | FileCheck %s --check-prefix=CHECK32

+ ///----------------------------------------------------------------------------------------
+ /// vector.load
+ ///----------------------------------------------------------------------------------------
+
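+ // A sketch of the rewrite this file exercises (illustrative shapes; the
+ // CHECK lines below are authoritative): with memref-load-bitwidth=8, a
+ // sub-byte load such as
+ //   %1 = vector.load %0[%r, %c] : memref<3x8xi4>, vector<8xi4>
+ // is emulated on a linearized byte buffer (3*8 i4 values = 12 bytes) as
+ //   %bytes = vector.load %alloc[%lidx] : memref<12xi8>, vector<4xi8>
+ //   %1 = vector.bitcast %bytes : vector<4xi8> to vector<8xi4>
+ // while i8 ops are already byte-aligned and are left unchanged.
+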
func.func @vector_load_i8(%arg1: index, %arg2: index) -> vector<4xi8> {
  %0 = memref.alloc() : memref<3x4xi8>
  %1 = vector.load %0[%arg1, %arg2] : memref<3x4xi8>, vector<4xi8>
@@ -82,6 +86,10 @@ func.func @vector_load_i4_dynamic(%arg0 : index, %arg1 : index, %arg2 : index, %

// -----

+ ///----------------------------------------------------------------------------------------
+ /// vector.transfer_read
+ ///----------------------------------------------------------------------------------------
+
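+ // As with vector.load, the i4 reads in this group are expected to become
+ // wider reads off the linearized byte buffer followed by a vector.bitcast
+ // back to i4. A sketch (assumed output shape; the padding widening via
+ // arith.extui is an assumption, not taken from the CHECK lines here):
+ //   %pad = arith.extui %c0 : i4 to i8
+ //   %v8  = vector.transfer_read %alloc[%lidx], %pad : memref<12xi8>, vector<4xi8>
+ //   %v   = vector.bitcast %v8 : vector<4xi8> to vector<8xi4>
+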
func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> {
  %c0 = arith.constant 0 : i4
  %0 = memref.alloc() : memref<3x8xi4>
@@ -111,6 +119,10 @@ func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> {

// -----

+ ///----------------------------------------------------------------------------------------
+ /// vector.maskedload
+ ///----------------------------------------------------------------------------------------
+
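+ // The maskedload emulation follows a rescale + select idiom (a sketch with
+ // assumed SSA names; the CHECK lines below are authoritative): the index and
+ // mask are rescaled to the wider element type, the wide elements are
+ // masked-loaded, bitcast back, and the original mask picks between loaded
+ // lanes and the passthru. For 8 x i4 read as 4 x i8:
+ //   %new_mask = vector.create_mask %n_bytes : vector<4xi1>
+ //   %wide = vector.maskedload %alloc[%lidx], %new_mask, %pass : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+ //   %cast = vector.bitcast %wide : vector<4xi8> to vector<8xi4>
+ //   %res  = arith.select %orig_mask, %cast, %passthru : vector<8xi1>, vector<8xi4>
+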
func.func @vector_maskedload_i8(%arg1: index, %arg2: index, %arg3: index, %passthru: vector<4xi8>) -> vector<4xi8> {
  %0 = memref.alloc() : memref<3x4xi8>
  %mask = vector.create_mask %arg3 : vector<4xi1>
@@ -190,15 +202,15 @@ func.func @vector_maskedload_i4(%arg1: index, %arg2: index, %arg3: index, %passt

// -----

- func.func @vector_cst_maskedload_i8(%arg1: index, %arg2: index, %passthru: vector<4xi8>) -> vector<4xi8> {
+ func.func @vector_maskedload_i8_constant_mask(%arg1: index, %arg2: index, %passthru: vector<4xi8>) -> vector<4xi8> {
  %0 = memref.alloc() : memref<3x4xi8>
  %mask = vector.constant_mask [2] : vector<4xi1>
  %1 = vector.maskedload %0[%arg1, %arg2], %mask, %passthru :
    memref<3x4xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
  return %1 : vector<4xi8>
}
// Expect no conversions, i8 is supported.
- // CHECK: func @vector_cst_maskedload_i8(
+ // CHECK: func @vector_maskedload_i8_constant_mask(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: vector<4xi8>)
// CHECK-NEXT: %[[ALLOC:.+]] = memref.alloc() : memref<3x4xi8>
@@ -208,7 +220,7 @@ func.func @vector_cst_maskedload_i8(%arg1: index, %arg2: index, %passthru: vecto
// CHECK-NEXT: return

// CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 4)>
- // CHECK32: func @vector_cst_maskedload_i8(
+ // CHECK32: func @vector_maskedload_i8_constant_mask(
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK32-SAME: %[[ARG3:[a-zA-Z0-9]+]]: vector<4xi8>)
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
@@ -224,7 +236,7 @@ func.func @vector_cst_maskedload_i8(%arg1: index, %arg2: index, %passthru: vecto

// -----

- func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vector<8xi4>) -> vector<3x8xi4> {
+ func.func @vector_maskedload_i4_constant_mask(%arg1: index, %arg2: index, %passthru: vector<8xi4>) -> vector<3x8xi4> {
  %0 = memref.alloc() : memref<3x8xi4>
  %cst = arith.constant dense<0> : vector<3x8xi4>
  %mask = vector.constant_mask [4] : vector<8xi1>
@@ -234,7 +246,7 @@ func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vecto
  return %2 : vector<3x8xi4>
}
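// The linearization checked below, worked through: each row of memref<3x8xi4>
// holds 8 i4 values = 4 bytes, so element (r, c) lands in byte
// r * 4 + c floordiv 2 of the emulated memref<12xi8>, i.e.
// LOAD_IDX_MAP = (s0 * 4 + s1 floordiv 2). Under CHECK32 a row is exactly one
// i32 word, giving (s0 + s1 floordiv 8).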
// CHECK-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
- // CHECK: func @vector_cst_maskedload_i4(
+ // CHECK: func @vector_maskedload_i4_constant_mask(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: vector<8xi4>)
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
@@ -248,7 +260,7 @@ func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vecto
// CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG2]] : vector<8xi1>, vector<8xi4>

// CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
- // CHECK32: func @vector_cst_maskedload_i4(
+ // CHECK32: func @vector_maskedload_i4_constant_mask(
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: vector<8xi4>)
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
@@ -263,6 +275,10 @@ func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vecto

// -----

+ ///----------------------------------------------------------------------------------------
+ /// vector.extract -> vector.maskedload
+ ///----------------------------------------------------------------------------------------
+
func.func @vector_extract_maskedload_i4(%arg1: index) -> vector<8x8x16xi4> {
  %0 = memref.alloc() : memref<8x8x16xi4>
  %c0 = arith.constant 0 : index
@@ -353,6 +369,10 @@ func.func @vector_extract_cst_maskedload_i4() -> vector<8x8x16xi4> {

// -----

+ ///----------------------------------------------------------------------------------------
+ /// vector.store
+ ///----------------------------------------------------------------------------------------
+
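+ // A full, byte-aligned narrow store needs no read-modify-write: whole bytes
+ // are overwritten, so the value can simply be bitcast to the wider element
+ // type and stored (a sketch under that alignment assumption):
+ //   %bytes = vector.bitcast %arg0 : vector<8xi4> to vector<4xi8>
+ //   vector.store %bytes, %alloc[%lidx] : memref<12xi8>, vector<4xi8>
+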
func.func @vector_store_i8(%arg0: vector<8xi8>, %arg1: index, %arg2: index) {
  %0 = memref.alloc() : memref<4x8xi8>
  vector.store %arg0, %0[%arg1, %arg2] : memref<4x8xi8>, vector<8xi8>
@@ -431,6 +451,10 @@ func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: ind

// -----

+ ///----------------------------------------------------------------------------------------
+ /// vector.maskedstore
+ ///----------------------------------------------------------------------------------------
+
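+ // A masked store cannot blindly overwrite whole bytes: a byte may mix lanes
+ // that are and are not covered by the mask. The emulation is therefore a
+ // read-modify-write, sketched below with assumed SSA names (the i4 tests
+ // that follow spell out the verified sequence):
+ //   %prev = vector.maskedload %alloc[%lidx], %new_mask, %zeros : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+ //   %prev_i4 = vector.bitcast %prev : vector<4xi8> to vector<8xi4>
+ //   %merged = arith.select %orig_mask, %value, %prev_i4 : vector<8xi1>, vector<8xi4>
+ //   %merged_i8 = vector.bitcast %merged : vector<8xi4> to vector<4xi8>
+ //   vector.maskedstore %alloc[%lidx], %new_mask, %merged_i8 : memref<12xi8>, vector<4xi1>, vector<4xi8>
+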
func.func @vector_maskedstore_i8(%arg0: index, %arg1: index, %arg2: index, %value: vector<8xi8>) {
  %0 = memref.alloc() : memref<3x8xi8>
  %mask = vector.create_mask %arg2 : vector<8xi1>
@@ -469,14 +493,68 @@ func.func @vector_maskedstore_i8(%arg0: index, %arg1: index, %arg2: index, %valu

// -----

- func.func @vector_cst_maskedstore_i8(%arg0: index, %arg1: index, %value: vector<8xi8>) {
+ func.func @vector_maskedstore_i4(
+     %idx1: index,
+     %idx2: index,
+     %num_elements_to_store: index,
+     %value: vector<8xi4>) {
+
+   %0 = memref.alloc() : memref<3x8xi4>
+   %mask = vector.create_mask %num_elements_to_store : vector<8xi1>
+   vector.maskedstore %0[%idx1, %idx2], %mask, %value :
+     memref<3x8xi4>, vector<8xi1>, vector<8xi4>
+   return
+ }
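+ // The two affine maps checked below encode the index and mask rescaling.
+ // Worked through for n = 5 i4 elements (illustrative arithmetic, not extra
+ // FileCheck input): the store location linearizes to
+ // idx1 * 4 + idx2 floordiv 2, and the byte-level mask must cover
+ // ceil(n/2) = (n + 1) floordiv 2 = 3 bytes; under CHECK32 the same store
+ // touches ceil(n/8) = (n + 7) floordiv 8 = 1 i32 word.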
+ // CHECK: #[[$ATTR_10:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+ // CHECK: #[[$ATTR_11:.+]] = affine_map<()[s0] -> ((s0 + 1) floordiv 2)>
+
+ // CHECK-LABEL: func.func @vector_maskedstore_i4(
+ // CHECK-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
+ // CHECK-SAME: %[[IDX_2:[a-zA-Z0-9]+]]: index,
+ // CHECK-SAME: %[[NUM_EL_TO_STORE:[a-zA-Z0-9]+]]: index,
+ // CHECK-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
+ // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+ // CHECK: %[[ORIG_MASK:.+]] = vector.create_mask %[[NUM_EL_TO_STORE]] : vector<8xi1>
+ // CHECK: %[[LIDX:.+]] = affine.apply #[[$ATTR_10]]()[%[[IDX_1]], %[[IDX_2]]]
+ // CHECK: %[[MASK_IDX:.+]] = affine.apply #[[$ATTR_11]]()[%[[NUM_EL_TO_STORE]]]
+ // CHECK: %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<4xi1>
+ // CHECK: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<4xi8>
+ // CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+ // CHECK: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
+ // CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
+ // CHECK: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<4xi8>
+ // CHECK: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<12xi8>, vector<4xi1>, vector<4xi8>
+
+ // CHECK32: #[[$ATTR_17:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+ // CHECK32: #[[$ATTR_18:.+]] = affine_map<()[s0] -> ((s0 + 7) floordiv 8)>
+
+ // CHECK32-LABEL: func.func @vector_maskedstore_i4(
+ // CHECK32-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
+ // CHECK32-SAME: %[[IDX_2:[a-zA-Z0-9]+]]: index,
+ // CHECK32-SAME: %[[NUM_EL_TO_STORE:[a-zA-Z0-9]+]]: index,
+ // CHECK32-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
+ // CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+ // CHECK32: %[[ORIG_MASK:.+]] = vector.create_mask %[[NUM_EL_TO_STORE]] : vector<8xi1>
+ // CHECK32: %[[LIDX:.+]] = affine.apply #[[$ATTR_17]]()[%[[IDX_1]], %[[IDX_2]]]
+ // CHECK32: %[[MASK_IDX:.+]] = affine.apply #[[$ATTR_18]]()[%[[NUM_EL_TO_STORE]]]
+ // CHECK32: %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<1xi1>
+ // CHECK32: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<1xi32>
+ // CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
+ // CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4>
+ // CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
+ // CHECK32: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32>
+ // CHECK32: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<3xi32>, vector<1xi1>, vector<1xi32>
+
+ // -----
+
+ func.func @vector_maskedstore_i8_constant_mask(%arg0: index, %arg1: index, %value: vector<8xi8>) {
  %0 = memref.alloc() : memref<3x8xi8>
  %mask = vector.constant_mask [4] : vector<8xi1>
  vector.maskedstore %0[%arg0, %arg1], %mask, %value : memref<3x8xi8>, vector<8xi1>, vector<8xi8>
  return
}
// Expect no conversions, i8 is supported.
- // CHECK: func @vector_cst_maskedstore_i8(
+ // CHECK: func @vector_maskedstore_i8_constant_mask(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[VAL:[a-zA-Z0-9]+]]
@@ -486,7 +564,7 @@ func.func @vector_cst_maskedstore_i8(%arg0: index, %arg1: index, %value: vector<
// CHECK-NEXT: return

// CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 2 + s1 floordiv 4)>
- // CHECK32: func @vector_cst_maskedstore_i8(
+ // CHECK32: func @vector_maskedstore_i8_constant_mask(
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK32-SAME: %[[VAL:[a-zA-Z0-9]+]]
@@ -500,3 +578,49 @@ func.func @vector_cst_maskedstore_i8(%arg0: index, %arg1: index, %value: vector<
// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL]], %[[BITCAST]] : vector<8xi1>, vector<8xi8>
// CHECK32: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi8> to vector<2xi32>
// CHECK32: vector.maskedstore %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]]
+
+ // -----
+
+ func.func @vector_maskedstore_i4_constant_mask(
+     %idx_1: index,
+     %idx_2: index,
+     %val_to_store: vector<8xi4>) {
+
+   %0 = memref.alloc() : memref<3x8xi4>
+   %mask = vector.constant_mask [4] : vector<8xi1>
+   vector.maskedstore %0[%idx_1, %idx_2], %mask, %val_to_store :
+     memref<3x8xi4>, vector<8xi1>, vector<8xi4>
+   return
+ }
+
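+ // With a constant mask the rescaling folds to another constant mask:
+ // [4] of vector<8xi1> covers 4 i4 values = 2 bytes, so the emulated mask is
+ // constant_mask [2] of vector<4xi1>; under CHECK32 those 4 values sit in the
+ // first (and only) i32 word, giving constant_mask [1] of vector<1xi1>.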
+ // CHECK: #[[$ATTR_12:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+ // CHECK-LABEL: func.func @vector_maskedstore_i4_constant_mask(
+ // CHECK-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
+ // CHECK-SAME: %[[IDX_2:[a-zA-Z0-9]+]]: index,
+ // CHECK-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
+ // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+ // CHECK: %[[ORIG_MASK:.+]] = vector.constant_mask [4] : vector<8xi1>
+ // CHECK: %[[LIDX:.+]] = affine.apply #[[$ATTR_12]]()[%[[IDX_1]], %[[IDX_2]]]
+ // CHECK: %[[NEW_MASK:.+]] = vector.constant_mask [2] : vector<4xi1>
+ // CHECK: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<4xi8>
+ // CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+ // CHECK: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
+ // CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
+ // CHECK: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<4xi8>
+ // CHECK: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<12xi8>, vector<4xi1>, vector<4xi8>
+
+ // CHECK32: #[[$ATTR_20:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+ // CHECK32-LABEL: func.func @vector_maskedstore_i4_constant_mask(
+ // CHECK32-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
+ // CHECK32-SAME: %[[IDX_2:[a-zA-Z0-9]+]]: index,
+ // CHECK32-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
+ // CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+ // CHECK32: %[[ORIG_MASK:.+]] = vector.constant_mask [4] : vector<8xi1>
+ // CHECK32: %[[LIDX:.+]] = affine.apply #[[$ATTR_20]]()[%[[IDX_1]], %[[IDX_2]]]
+ // CHECK32: %[[NEW_MASK:.+]] = vector.constant_mask [1] : vector<1xi1>
+ // CHECK32: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<1xi32>
+ // CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
+ // CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4>
+ // CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
+ // CHECK32: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32>
+ // CHECK32: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<3xi32>, vector<1xi1>, vector<1xi32>