// RUN: mlir-opt --transform-interpreter --cse --split-input-file %s | FileCheck %s
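
// Fuses the `linalg.fill` that initializes the destination of a tiled
// `linalg.matmul`, where the fill result is reached through multiple levels of
// `tensor.extract_slice` (an outer `scf.forall` plus two nested `scf.for` loops).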
#map = affine_map<(d0) -> (d0 * 128)>
module {
  func.func @gemm_fill_fusion_multi_level_extract_slice(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
    %c0 = arith.constant 0 : index
    %c64 = arith.constant 64 : index
    %c128 = arith.constant 128 : index
    %cst = arith.constant 0.000000e+00 : f32
    %dest0 = tensor.empty() : tensor<256x256xf32>
    %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
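    // The outer scf.forall tiles the 256x256 output into 128x128 tiles.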
    %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %dest1) -> tensor<256x256xf32> {
      %iv0 = affine.apply #map(%arg3)
      %iv1 = affine.apply #map(%arg4)
      %extracted_slice_1 = tensor.extract_slice %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
      %extracted_slice_2 = tensor.extract_slice %arg0[%iv0, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
      %extracted_slice_3 = tensor.extract_slice %arg1[0, %iv1] [512, 128] [1, 1] : tensor<512x256xf32> to tensor<512x128xf32>
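      // Two nested scf.for loops further tile each 128x128 tile into 64x64 pieces.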
      %2 = scf.for %arg6 = %c0 to %c128 step %c64 iter_args(%arg7 = %extracted_slice_1) -> (tensor<128x128xf32>) {
        %3 = scf.for %arg8 = %c0 to %c128 step %c64 iter_args(%arg9 = %arg7) -> (tensor<128x128xf32>) {
          %extracted_slice_4 = tensor.extract_slice %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
          %extracted_slice_5 = tensor.extract_slice %extracted_slice_2[%arg6, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
          %extracted_slice_6 = tensor.extract_slice %extracted_slice_3[0, %arg8] [512, 64] [1, 1] : tensor<512x128xf32> to tensor<512x64xf32>
          %4 = linalg.matmul ins(%extracted_slice_5, %extracted_slice_6 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_4 : tensor<64x64xf32>) -> tensor<64x64xf32>
          %insert_slice = tensor.insert_slice %4 into %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
          scf.yield %insert_slice : tensor<128x128xf32>
        }
        scf.yield %3 : tensor<128x128xf32>
      }
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %2 into %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
      }
    }
    return %1 : tensor<256x256xf32>
  }
}

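// The transform sequence matches the `linalg.matmul`, takes the producer of its
// destination operand (operand #2), and applies `transform.test.fuse_producer`
// to it, fusing the `linalg.fill` into the tiled loop nest.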
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    %yield = transform.get_producer_of_operand %matmul[2]
      : (!transform.any_op) -> !transform.any_op
    %a, %b = transform.test.fuse_producer %yield
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}

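// After fusion, the tiled `linalg.fill` is expected inside the innermost
// scf.for, initializing the 64x64 destination slice consumed by the tiled
// linalg.matmul.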
// CHECK: #[[MAP0:.*]] = affine_map<(d0) -> (d0 * 128)>
// CHECK: func.func @gemm_fill_fusion_multi_level_extract_slice(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
// CHECK: %[[FORALL_RESULT:.*]] = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
// CHECK-SAME: shared_outs(%[[INIT_ARG0:.*]] = %[[dest0]])
// CHECK-SAME: {
// CHECK: %[[AFFINE_IV1:.*]] = affine.apply #[[MAP0]](%[[IV1]])
// CHECK: %[[AFFINE_IV2:.*]] = affine.apply #[[MAP0]](%[[IV2]])
// CHECK: %[[FILL_OUT_SLICE0:.*]] = tensor.extract_slice %[[INIT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
// CHECK: %[[INPUT_SLICE0:.*]] = tensor.extract_slice %[[ARG0]][%[[AFFINE_IV1]], 0] [128, 512] [1, 1]
// CHECK: %[[WEIGHT_SLICE0:.*]] = tensor.extract_slice %[[ARG1]][0, %[[AFFINE_IV2]]] [512, 128] [1, 1]
// CHECK: %[[LOOP_RESULT1:.*]] = scf.for %[[IV3:.*]] = %[[C0]]
// CHECK-SAME: iter_args(%[[INIT_ARG1:.*]] = %[[FILL_OUT_SLICE0]])
// CHECK-SAME: {
// CHECK: %[[LOOP_RESULT2:.*]] = scf.for %[[IV4:.*]] = %[[C0]]
// CHECK-SAME: iter_args(%[[INIT_ARG2:.*]] = %[[INIT_ARG1]])
// CHECK-SAME: {
// CHECK: %[[FILL_OUT_SLICE1:.*]] = tensor.extract_slice %[[INIT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
// CHECK: %[[TILED_FILL_OUT:.*]] = linalg.fill
// CHECK-SAME: outs(%[[FILL_OUT_SLICE1]] :
// CHECK: %[[INPUT_SLICE1:.*]] = tensor.extract_slice %[[INPUT_SLICE0]][%[[IV3]], 0] [64, 512] [1, 1]
// CHECK: %[[WEIGHT_SLICE1:.*]] = tensor.extract_slice %[[WEIGHT_SLICE0]][0, %[[IV4]]] [512, 64] [1, 1]
// CHECK: %[[TILED_MAT_OUT:.*]] = linalg.matmul
// CHECK-SAME: outs(%[[TILED_FILL_OUT]] :
// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[INIT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
// CHECK: scf.yield %[[INSERT_MAT]] :
// CHECK: }
// CHECK: scf.yield %[[LOOP_RESULT2]] :
// CHECK: }
// CHECK: scf.forall.in_parallel {
// CHECK: tensor.parallel_insert_slice %[[LOOP_RESULT1]] into %[[INIT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
// CHECK: }
// CHECK: }
// CHECK: return %[[FORALL_RESULT]] :