@@ -361,6 +361,74 @@ func.func @vector_maskedload_i2_constant_mask_unaligned(%passthru: vector<5xi2>)
361
361
/// vector.store
362
362
///----------------------------------------------------------------------------------------
363
363
364
+ // -----
365
+
366
+ // Most basic example to demonstrate where partial stores are not needed.
367
+
368
+ func.func @vector_store_i2_const_index_no_partial_store (%arg0: vector <4 xi2 >) {
369
+ %0 = memref.alloc () : memref <13 xi2 >
370
+ %c4 = arith.constant 4 : index
371
+ vector.store %arg0 , %0 [%c4 ] : memref <13 xi2 >, vector <4 xi2 >
372
+ return
373
+ }
374
+ // CHECK-LABEL: func.func @vector_store_i2_const_index_no_partial_store(
375
+ // CHECK-SAME: %[[ARG_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: vector<4xi2>) {
376
+ // CHECK-NOT: memref.generic_atomic_rmw
377
+ // CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<4xi8>
378
+ // CHECK: %[[UPCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4xi2> to vector<1xi8>
379
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
380
+ // CHECK: vector.store %[[UPCAST]], %[[ALLOC]]{{\[}}%[[C1]]] : memref<4xi8>, vector<1xi8>
381
+
382
+ // -----
383
+
384
+ // Small modification of the example above to demonstrate where partial stores
385
+ // are needed.
386
+
387
+ func.func @vector_store_i2_const_index_two_partial_stores (%arg0: vector <4 xi2 >) {
388
+ %0 = memref.alloc () : memref <13 xi2 >
389
+ %c3 = arith.constant 3 : index
390
+ vector.store %arg0 , %0 [%c3 ] : memref <13 xi2 >, vector <4 xi2 >
391
+ return
392
+ }
393
+
394
+ // CHECK-LABEL: func.func @vector_store_i2_const_index_two_partial_stores(
395
+ // CHECK-SAME: %[[ARG_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: vector<4xi2>) {
396
+ // CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<4xi8>
397
+
398
+ // First atomic RMW:
399
+ // CHECK: %[[IDX_1:.*]] = arith.constant 0 : index
400
+ // CHECK: %[[MASK_1:.*]] = arith.constant dense<[false, false, false, true]> : vector<4xi1>
401
+ // CHECK: %[[INIT:.*]] = arith.constant dense<0> : vector<4xi2>
402
+ // CHECK: %[[SLICE_1:.*]] = vector.extract_strided_slice %[[ARG_0]] {offsets = [0], sizes = [1], strides = [1]} : vector<4xi2> to vector<1xi2>
403
+ // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[SLICE_1]], %[[INIT]] {offsets = [3], strides = [1]} : vector<1xi2> into vector<4xi2>
404
+ // CHECK: memref.generic_atomic_rmw %[[VAL_1]]{{\[}}%[[IDX_1]]] : memref<4xi8> {
405
+ // CHECK: ^bb0(%[[VAL_8:.*]]: i8):
406
+ // CHECK: %[[VAL_9:.*]] = vector.from_elements %[[VAL_8]] : vector<1xi8>
407
+ // CHECK: %[[DOWNCAST_1:.*]] = vector.bitcast %[[VAL_9]] : vector<1xi8> to vector<4xi2>
408
+ // CHECK: %[[SELECT_1:.*]] = arith.select %[[MASK_1]], %[[V1]], %[[DOWNCAST_1]] : vector<4xi1>, vector<4xi2>
409
+ // CHECK: %[[UPCAST_1:.*]] = vector.bitcast %[[SELECT_1]] : vector<4xi2> to vector<1xi8>
410
+ // CHECK: %[[RES_1:.*]] = vector.extract %[[UPCAST_1]][0] : i8 from vector<1xi8>
411
+ // CHECK: memref.atomic_yield %[[RES_1]] : i8
412
+ // CHECK: }
413
+
414
+ // Second atomic RMW:
415
+ // CHECK: %[[VAL_14:.*]] = arith.constant 1 : index
416
+ // CHECK: %[[IDX_2:.*]] = arith.addi %[[IDX_1]], %[[VAL_14]] : index
417
+ // CHECK: %[[VAL_16:.*]] = vector.extract_strided_slice %[[ARG_0]] {offsets = [1], sizes = [3], strides = [1]} : vector<4xi2> to vector<3xi2>
418
+ // CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[VAL_16]], %[[INIT]] {offsets = [0], strides = [1]} : vector<3xi2> into vector<4xi2>
419
+ // CHECK: %[[MASK_2:.*]] = arith.constant dense<[true, true, true, false]> : vector<4xi1>
420
+ // CHECK: memref.generic_atomic_rmw %[[VAL_1]]{{\[}}%[[IDX_2]]] : memref<4xi8> {
421
+ // CHECK: ^bb0(%[[VAL_20:.*]]: i8):
422
+ // CHECK: %[[VAL_21:.*]] = vector.from_elements %[[VAL_20]] : vector<1xi8>
423
+ // CHECK: %[[DONWCAST_2:.*]] = vector.bitcast %[[VAL_21]] : vector<1xi8> to vector<4xi2>
424
+ // CHECK: %[[SELECT_2:.*]] = arith.select %[[MASK_2]], %[[V2]], %[[DONWCAST_2]] : vector<4xi1>, vector<4xi2>
425
+ // CHECK: %[[UPCAST_2:.*]] = vector.bitcast %[[SELECT_2]] : vector<4xi2> to vector<1xi8>
426
+ // CHECK: %[[RES_2:.*]] = vector.extract %[[UPCAST_2]][0] : i8 from vector<1xi8>
427
+ // CHECK: memref.atomic_yield %[[RES_2]] : i8
428
+ // CHECK: }
429
+
430
+ // -----
431
+
364
432
func.func @vector_store_i2_const_index_two_partial_stores (%arg0: vector <3 xi2 >) {
365
433
%src = memref.alloc () : memref <3 x3 xi2 >
366
434
%c0 = arith.constant 0 : index
0 commit comments