1
1
// RUN: gc-opt --split-input-file -any-tilable-fusion %s
2
2
3
- func.func @mlp (%arg0: tensor <128 x512 xbf16 >, %arg1: tensor <32 x8 x16 x32 xbf16 >, %arg2: tensor <256 xbf16 >) -> tensor <128 x256 xbf16 > {
3
+ module {
4
+ func.func @mlp (%arg0: tensor <128 x512 xbf16 >, %arg1: tensor <32 x8 x16 x32 xbf16 >, %arg2: tensor <256 xbf16 >) -> tensor <128 x256 xbf16 > {
4
5
%c32 = arith.constant 32 : index
5
6
%c512 = arith.constant 512 : index
6
7
%c128 = arith.constant 128 : index
@@ -58,4 +59,83 @@ func.func @mlp(%arg0: tensor<128x512xbf16>, %arg1: tensor<32x8x16x32xbf16>, %arg
58
59
%3 = linalg.add ins (%2 , %broadcasted : tensor <128 x256 xbf16 >, tensor <128 x256 xbf16 >) outs (%0 : tensor <128 x256 xbf16 >) -> tensor <128 x256 xbf16 >
59
60
%4 = linalg.exp ins (%3 : tensor <128 x256 xbf16 >) outs (%0 : tensor <128 x256 xbf16 >) -> tensor <128 x256 xbf16 >
60
61
return %4 : tensor <128 x256 xbf16 >
61
- }
62
+ }
63
+ }
64
+
65
+ // -----
66
+
67
// Test input for the pass named on the RUN line.
//
// A 256x512 by 512x256 matmul is pre-tiled by hand:
//   * an scf.forall over a 2x2 grid, each iteration owning one 128x128 block
//     of the zero-filled 256x256 result (#map scales the grid index by 128);
//   * inside each block, two nested scf.for loops step by 64 and run a
//     linalg.matmul on one 64x64 tile at a time.
// The scf.forall result then feeds two separate linalg.add ops, and both sums
// are returned.
// NOTE(review): judging by the function name, this presumably exercises
// fusing more than one consumer into the loop nest -- confirm against the pass.
#map = affine_map<(d0) -> (d0 * 128)>
module {
  func.func @fuse_multiple_consumer(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>, %arg3: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
    // Loop bounds for the inner 64-wide tiling of a 128x128 block.
    %lb = arith.constant 0 : index
    %step = arith.constant 64 : index
    %ub = arith.constant 128 : index
    %zero = arith.constant 0.000000e+00 : f32
    // %empty is also reused below as the `outs` operand of both linalg.add ops.
    %empty = tensor.empty() : tensor<256x256xf32>
    %acc_init = linalg.fill ins(%zero : f32) outs(%empty : tensor<256x256xf32>) -> tensor<256x256xf32>
    %mm = scf.forall (%bi, %bj) in (2, 2) shared_outs(%out = %acc_init) -> tensor<256x256xf32> {
      // Element offsets of this iteration's block.
      %row0 = affine.apply #map(%bi)
      %col0 = affine.apply #map(%bj)
      %c_blk = tensor.extract_slice %out[%row0, %col0] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
      %a_blk = tensor.extract_slice %arg0[%row0, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
      %b_blk = tensor.extract_slice %arg1[0, %col0] [512, 128] [1, 1] : tensor<512x256xf32> to tensor<512x128xf32>
      %blk = scf.for %i = %lb to %ub step %step iter_args(%outer_acc = %c_blk) -> (tensor<128x128xf32>) {
        %row = scf.for %j = %lb to %ub step %step iter_args(%inner_acc = %outer_acc) -> (tensor<128x128xf32>) {
          %c_tile = tensor.extract_slice %inner_acc[%i, %j] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
          %a_tile = tensor.extract_slice %a_blk[%i, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
          %b_tile = tensor.extract_slice %b_blk[0, %j] [512, 64] [1, 1] : tensor<512x128xf32> to tensor<512x64xf32>
          %prod = linalg.matmul ins(%a_tile, %b_tile : tensor<64x512xf32>, tensor<512x64xf32>) outs(%c_tile : tensor<64x64xf32>) -> tensor<64x64xf32>
          // Write the finished tile back into the block carried by the loop.
          %updated = tensor.insert_slice %prod into %inner_acc[%i, %j] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
          scf.yield %updated : tensor<128x128xf32>
        }
        scf.yield %row : tensor<128x128xf32>
      }
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %blk into %out[%row0, %col0] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
      }
    }
    // Two consumers of the same tiled matmul result.
    %sum0 = linalg.add ins(%mm, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%empty : tensor<256x256xf32>) -> tensor<256x256xf32>
    %sum1 = linalg.add ins(%mm, %arg3 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%empty : tensor<256x256xf32>) -> tensor<256x256xf32>
    return %sum0, %sum1 : tensor<256x256xf32>, tensor<256x256xf32>
  }
}
102
+
103
+ // -----
104
+
105
// Test input for the pass named on the RUN line.
//
// A 256x512 by 512x256 matmul is pre-tiled by hand:
//   * an scf.forall over a 2x1 grid, each iteration owning one 128x256 block
//     of the zero-filled 256x256 result (#map scales the grid index by 128;
//     the second grid dimension has extent 1, so its offset is always 0);
//   * inside each block, two nested scf.for loops step by 64 (rows 0..128,
//     columns 0..256) and run a linalg.matmul on one 64x64 tile at a time.
// The scf.forall result is added to %arg2 and the sum is reduced with
// arith.addf over dimension 1, giving a tensor<256xf32>.
#map = affine_map<(d0) -> (d0 * 128)>
module {
  func.func @fuse_reduce(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256xf32> {
    %lb = arith.constant 0 : index
    %step = arith.constant 64 : index
    // Upper bounds of the row loop (128) and the column loop (256).
    %row_ub = arith.constant 128 : index
    %col_ub = arith.constant 256 : index
    %zero = arith.constant 0.000000e+00 : f32
    // %empty is also reused below as the `outs` operand of linalg.add.
    %empty = tensor.empty() : tensor<256x256xf32>
    %acc_init = linalg.fill ins(%zero : f32) outs(%empty : tensor<256x256xf32>) -> tensor<256x256xf32>
    %mm = scf.forall (%bi, %bj) in (2, 1) shared_outs(%out = %acc_init) -> tensor<256x256xf32> {
      // Element offsets of this iteration's block.
      %row0 = affine.apply #map(%bi)
      %col0 = affine.apply #map(%bj)
      %c_blk = tensor.extract_slice %out[%row0, %col0] [128, 256] [1, 1] : tensor<256x256xf32> to tensor<128x256xf32>
      %a_blk = tensor.extract_slice %arg0[%row0, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
      %b_blk = tensor.extract_slice %arg1[0, %col0] [512, 256] [1, 1] : tensor<512x256xf32> to tensor<512x256xf32>
      %blk = scf.for %i = %lb to %row_ub step %step iter_args(%outer_acc = %c_blk) -> (tensor<128x256xf32>) {
        %row = scf.for %j = %lb to %col_ub step %step iter_args(%inner_acc = %outer_acc) -> (tensor<128x256xf32>) {
          %c_tile = tensor.extract_slice %inner_acc[%i, %j] [64, 64] [1, 1] : tensor<128x256xf32> to tensor<64x64xf32>
          %a_tile = tensor.extract_slice %a_blk[%i, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
          %b_tile = tensor.extract_slice %b_blk[0, %j] [512, 64] [1, 1] : tensor<512x256xf32> to tensor<512x64xf32>
          %prod = linalg.matmul ins(%a_tile, %b_tile : tensor<64x512xf32>, tensor<512x64xf32>) outs(%c_tile : tensor<64x64xf32>) -> tensor<64x64xf32>
          // Write the finished tile back into the block carried by the loop.
          %updated = tensor.insert_slice %prod into %inner_acc[%i, %j] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x256xf32>
          scf.yield %updated : tensor<128x256xf32>
        }
        scf.yield %row : tensor<128x256xf32>
      }
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %blk into %out[%row0, %col0] [128, 256] [1, 1] : tensor<128x256xf32> into tensor<256x256xf32>
      }
    }
    %sum = linalg.add ins(%mm, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%empty : tensor<256x256xf32>) -> tensor<256x256xf32>
    // NOTE(review): the reduction accumulates into an unfilled tensor.empty, so
    // the numeric result is undefined. Presumably acceptable for a
    // transformation-only test -- confirm, or add a linalg.fill if values matter.
    %red_init = tensor.empty() : tensor<256xf32>
    %red = linalg.reduce { arith.addf } ins(%sum : tensor<256x256xf32>) outs(%red_init : tensor<256xf32>) dimensions = [1]
    return %red : tensor<256xf32>
  }
}
0 commit comments