@@ -447,41 +447,68 @@ transform.sequence failures(propagate) {
447
447
448
448
// -----
449
449
450
- func.func @vectorize_dynamic_matmul (%A: memref <?x?xf32 >, %B: memref <?x?xf32 >, %C: memref <?x?xf32 >) {
450
+ func.func @matmul (%A: memref <?x?xf32 >, %B: memref <?x?xf32 >, %C: memref <?x?xf32 >) {
451
451
linalg.matmul ins (%A , %B: memref <?x?xf32 >, memref <?x?xf32 >)
452
452
outs (%C: memref <?x?xf32 >)
453
453
return
454
454
}
455
455
456
- // CHECK-LABEL: func.func @vectorize_dynamic_matmul (
457
- // CHECK-SAME: %[[VAL_0 :.*]]: memref<?x?xf32>, %[[VAL_1 :.*]]: memref<?x?xf32>, %[[VAL_2 :.*]]: memref<?x?xf32>) {
456
+ // CHECK-LABEL: func.func @matmul (
457
+ // CHECK-SAME: %[[A :.*]]: memref<?x?xf32>, %[[B :.*]]: memref<?x?xf32>, %[[C :.*]]: memref<?x?xf32>) {
458
458
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
459
- // CHECK-DAG: %[[VAL_4:.*]] = memref.dim %[[VAL_0 ]], %[[VAL_3]] : memref<?x?xf32>
459
+ // CHECK-DAG: %[[VAL_4:.*]] = memref.dim %[[A ]], %[[VAL_3]] : memref<?x?xf32>
460
460
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
461
- // CHECK-DAG: %[[VAL_6:.*]] = memref.dim %[[VAL_1 ]], %[[VAL_5]] : memref<?x?xf32>
461
+ // CHECK-DAG: %[[VAL_6:.*]] = memref.dim %[[B ]], %[[VAL_5]] : memref<?x?xf32>
462
462
// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 1 : index
463
- // CHECK-DAG: %[[VAL_8:.*]] = memref.dim %[[VAL_0]], %[[VAL_7]] : memref<?x?xf32>
464
- // CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0 : index
465
- // CHECK-DAG: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
466
- // CHECK: %[[VAL_11:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_8]] : vector<8x4xi1>
467
- // CHECK: %[[VAL_12:.*]] = vector.mask %[[VAL_11]] { vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_9]], %[[VAL_9]]], %[[VAL_10]] {in_bounds = [true, true, true], permutation_map = #map} : memref<?x?xf32>, vector<8x16x4xf32> } : vector<8x4xi1> -> vector<8x16x4xf32>
468
- // CHECK: %[[VAL_13:.*]] = arith.constant 0.000000e+00 : f32
469
- // CHECK: %[[VAL_14:.*]] = vector.create_mask %[[VAL_8]], %[[VAL_6]] : vector<4x16xi1>
470
- // CHECK: %[[VAL_15:.*]] = vector.mask %[[VAL_14]] { vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_9]], %[[VAL_9]]], %[[VAL_13]] {in_bounds = [true, true, true], permutation_map = #map1} : memref<?x?xf32>, vector<8x16x4xf32> } : vector<4x16xi1> -> vector<8x16x4xf32>
471
- // CHECK: %[[VAL_16:.*]] = arith.constant 0.000000e+00 : f32
472
- // CHECK: %[[VAL_17:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]] : vector<8x16xi1>
473
- // CHECK: %[[VAL_18:.*]] = vector.mask %[[VAL_17]] { vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_9]], %[[VAL_9]]], %[[VAL_16]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<8x16xf32> } : vector<8x16xi1> -> vector<8x16xf32>
474
- // CHECK: %[[VAL_19:.*]] = arith.mulf %[[VAL_12]], %[[VAL_15]] : vector<8x16x4xf32>
475
- // CHECK: %[[VAL_20:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]], %[[VAL_8]] : vector<8x16x4xi1>
476
- // CHECK: %[[VAL_21:.*]] = vector.mask %[[VAL_20]] { vector.multi_reduction <add>, %[[VAL_19]], %[[VAL_18]] [2] : vector<8x16x4xf32> to vector<8x16xf32> } : vector<8x16x4xi1> -> vector<8x16xf32>
477
- // CHECK: %[[VAL_22:.*]] = arith.constant 0 : index
478
- // CHECK: vector.mask %[[VAL_17]] { vector.transfer_write %[[VAL_21]], %[[VAL_2]]{{\[}}%[[VAL_22]], %[[VAL_22]]] {in_bounds = [true, true]} : vector<8x16xf32>, memref<?x?xf32> } : vector<8x16xi1>
479
- // CHECK: return
480
- // CHECK: }
463
+ // CHECK-DAG: %[[VAL_8:.*]] = memref.dim %[[A]], %[[VAL_7]] : memref<?x?xf32>
464
+ // CHECK: %[[MASK_A:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_8]] : vector<8x4xi1>
465
+ // CHECK: %[[LOAD_A:.*]] = vector.mask %[[MASK_A]] { vector.transfer_read %[[A]]{{\[}}%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<8x16x4xf32> } : vector<8x4xi1> -> vector<8x16x4xf32>
466
+ // CHECK: %[[MASK_B:.*]] = vector.create_mask %[[VAL_8]], %[[VAL_6]] : vector<4x16xi1>
467
+ // CHECK: %[[LOAD_B:.*]] = vector.mask %[[MASK_B]] { vector.transfer_read %[[B]]{{\[}}%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<8x16x4xf32> } : vector<4x16xi1> -> vector<8x16x4xf32>
468
+ // CHECK: %[[MASK_C:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]] : vector<8x16xi1>
469
+ // CHECK: %[[LOAD_C:.*]] = vector.mask %[[MASK_C]] { vector.transfer_read %[[C]]{{\[}}%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : memref<?x?xf32>, vector<8x16xf32> } : vector<8x16xi1> -> vector<8x16xf32>
470
+ // CHECK: %[[MULF:.*]] = arith.mulf %[[LOAD_A]], %[[LOAD_B]] : vector<8x16x4xf32>
471
+ // CHECK: %[[MASK_MULIT_RED:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]], %[[VAL_8]] : vector<8x16x4xi1>
472
+ // CHECK: %[[MULTI_RED:.*]] = vector.mask %[[MASK_MULIT_RED]] { vector.multi_reduction <add>, %[[MULF]], %[[LOAD_C]] [2] : vector<8x16x4xf32> to vector<8x16xf32> } : vector<8x16x4xi1> -> vector<8x16xf32>
473
+ // CHECK: %[[C2:.*]] = arith.constant 0 : index
474
+ // CHECK: vector.mask %[[MASK_C]] { vector.transfer_write %[[MULTI_RED]], %[[C]]{{\[}}%[[C2]], %[[C2]]] {in_bounds = [true, true]} : vector<8x16xf32>, memref<?x?xf32> } : vector<8x16xi1>
481
475
482
476
transform.sequence failures (propagate ) {
483
477
^bb1 (%arg1: !transform.any_op ):
484
- %0 = transform.structured.match ops {[" linalg.matmul" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
485
- transform.structured.masked_vectorize %0 vector_sizes [8 , 16 , 4 ] : !transform.any_op
478
+ %matmul = transform.structured.match ops {[" linalg.matmul" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
479
+ transform.structured.masked_vectorize %matmul vector_sizes [8 , 16 , 4 ] : !transform.any_op
480
+ }
481
+
482
+ // -----
483
+
484
+ func.func @matmul_scalable (%A: memref <?x?xf32 >, %B: memref <?x?xf32 >, %C: memref <?x?xf32 >) {
485
+ linalg.matmul ins (%A , %B: memref <?x?xf32 >, memref <?x?xf32 >)
486
+ outs (%C: memref <?x?xf32 >)
487
+ return
486
488
}
487
489
490
+ // CHECK-LABEL: func.func @matmul_scalable(
491
+ // CHECK-SAME: %[[A:.*]]: memref<?x?xf32>, %[[B:.*]]: memref<?x?xf32>, %[[C:.*]]: memref<?x?xf32>) {
492
+ // CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
493
+ // CHECK-DAG: %[[VAL_4:.*]] = memref.dim %[[A]], %[[VAL_3]] : memref<?x?xf32>
494
+ // CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
495
+ // CHECK-DAG: %[[VAL_6:.*]] = memref.dim %[[B]], %[[VAL_5]] : memref<?x?xf32>
496
+ // CHECK-DAG: %[[VAL_7:.*]] = arith.constant 1 : index
497
+ // CHECK-DAG: %[[VAL_8:.*]] = memref.dim %[[A]], %[[VAL_7]] : memref<?x?xf32>
498
+ // CHECK: %[[MASK_A:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_8]] : vector<8x4xi1>
499
+ // CHECK: %[[LOAD_A:.*]] = vector.mask %[[MASK_A]] { vector.transfer_read %[[A]]{{\[}}%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<8x[16]x4xf32> } : vector<8x4xi1> -> vector<8x[16]x4xf32>
500
+ // CHECK: %[[MASK_B:.*]] = vector.create_mask %[[VAL_8]], %[[VAL_6]] : vector<4x[16]xi1>
501
+ // CHECK: %[[LOAD_B:.*]] = vector.mask %[[MASK_B]] { vector.transfer_read %[[B]]{{\[}}%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<8x[16]x4xf32> } : vector<4x[16]xi1> -> vector<8x[16]x4xf32>
502
+ // CHECK: %[[MASK_C:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]] : vector<8x[16]xi1>
503
+ // CHECK: %[[LOAD_C:.*]] = vector.mask %[[MASK_C]] { vector.transfer_read %[[C]]{{\[}}%{{.*}}, %{{.*}}], %{{.*}} {in_bounds = [true, true]} : memref<?x?xf32>, vector<8x[16]xf32> } : vector<8x[16]xi1> -> vector<8x[16]xf32>
504
+ // CHECK: %[[MULF:.*]] = arith.mulf %[[LOAD_A]], %[[LOAD_B]] : vector<8x[16]x4xf32>
505
+ // CHECK: %[[MASK_MULIT_RED:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]], %[[VAL_8]] : vector<8x[16]x4xi1>
506
+ // CHECK: %[[MULTI_RED:.*]] = vector.mask %[[MASK_MULIT_RED]] { vector.multi_reduction <add>, %[[MULF]], %[[LOAD_C]] [2] : vector<8x[16]x4xf32> to vector<8x[16]xf32> } : vector<8x[16]x4xi1> -> vector<8x[16]xf32>
507
+ // CHECK: %[[C2:.*]] = arith.constant 0 : index
508
+ // CHECK: vector.mask %[[MASK_C]] { vector.transfer_write %[[MULTI_RED]], %[[C]]{{\[}}%[[C2]], %[[C2]]] {in_bounds = [true, true]} : vector<8x[16]xf32>, memref<?x?xf32> } : vector<8x[16]xi1>
509
+
510
+ transform.sequence failures (propagate ) {
511
+ ^bb1 (%arg1: !transform.any_op ):
512
+ %matmul = transform.structured.match ops {[" linalg.matmul" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op
513
+ transform.structured.masked_vectorize %matmul vector_sizes [8 , [16 ], 4 ] : !transform.any_op
514
+ }
0 commit comments