// -----

#map = affine_map<(d0) -> (d0 * 32)>
#map1 = affine_map<(d0) -> (d0 * 16)>
module {
  /// CHECK-LABEL: @fuse_mlp_vnni
  func.func @fuse_mlp_vnni(%arg0: tensor<128x1024xbf16>, %arg1: tensor<1024x512xbf16>, %arg2: tensor<512xbf16>) -> tensor<128x512xbf16> attributes {llvm.emit_c_interface} {
    %c2 = arith.constant 2 : index
    %c64 = arith.constant 64 : index
    %c0 = arith.constant 0 : index
    %cst = arith.constant dense<0.000000e+00> : tensor<128x512xbf16>
    /// CHECK: tensor.empty
    %0 = tensor.empty() : tensor<128x512xbf16>
    /// CHECK: tensor.empty
    %1 = tensor.empty() : tensor<16x64x16x32xbf16>
    // Pack the weight into blocked layout (16x32 tiles, outer dims permuted).
    %pack = tensor.pack %arg1 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [16, 32] into %1 : tensor<1024x512xbf16> -> tensor<16x64x16x32xbf16>
    /// CHECK: tensor.empty
    %2 = tensor.empty() : tensor<16x64x8x32x2xbf16>
    // Second pack introduces the VNNI inner dimension (tile of 2 on dim 2).
    %pack_0 = tensor.pack %pack inner_dims_pos = [2] inner_tiles = [2] into %2 : tensor<16x64x16x32xbf16> -> tensor<16x64x8x32x2xbf16>
    /// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%{{.*}}) in (16)
    %3 = scf.forall (%arg3) in (16) shared_outs(%arg4 = %0) -> (tensor<128x512xbf16>) {
      %9 = affine.apply #map(%arg3)
      %extracted_slice = tensor.extract_slice %arg4[0, %9] [128, 32] [1, 1] : tensor<128x512xbf16> to tensor<128x32xbf16>
      /// CHECK: tensor.empty
      %10 = tensor.empty() : tensor<128x32xf32>
      /// CHECK: linalg.copy
      %11 = linalg.copy ins(%extracted_slice : tensor<128x32xbf16>) outs(%10 : tensor<128x32xf32>) -> tensor<128x32xf32>
      /// CHECK: %[[TMP_RESULT:.*]]:2 = scf.for
      %12:2 = scf.for %arg5 = %c0 to %c64 step %c2 iter_args(%arg6 = %11, %arg7 = %extracted_slice) -> (tensor<128x32xf32>, tensor<128x32xbf16>) {
        %14 = affine.apply #map1(%arg5)
        %extracted_slice_1 = tensor.extract_slice %arg0[0, %14] [128, 32] [1, 1] : tensor<128x1024xbf16> to tensor<128x32xbf16>
        /// CHECK: %[[PACK_OUT:.*]] = tensor.pack
        /// CHECK: %[[PACK_OUT_VNNI:.*]] = tensor.pack %[[PACK_OUT]]
        %extracted_slice_2 = tensor.extract_slice %pack_0[%arg3, %arg5, 0, 0, 0] [1, 2, 8, 32, 2] [1, 1, 1, 1, 1] : tensor<16x64x8x32x2xbf16> to tensor<1x2x8x32x2xbf16>
        /// CHECK: %[[COLLAPSE_OUT:.*]] = tensor.collapse_shape %[[PACK_OUT_VNNI]]
        %collapsed = tensor.collapse_shape %extracted_slice_2 [[0, 1], [2], [3], [4]] : tensor<1x2x8x32x2xbf16> into tensor<2x8x32x2xbf16>
        /// CHECK: %[[EXPAND_OUT:.*]] = tensor.expand_shape
        %expanded = tensor.expand_shape %extracted_slice_1 [[0], [1, 2]] output_shape [128, 2, 16] : tensor<128x32xbf16> into tensor<128x2x16xbf16>
        %15 = tensor.empty() : tensor<2x128x16xbf16>
        /// CHECK: %[[TRANSPOSE_OUT:.*]] = linalg.transpose ins(%[[EXPAND_OUT]] :
        %transposed = linalg.transpose ins(%expanded : tensor<128x2x16xbf16>) outs(%15 : tensor<2x128x16xbf16>) permutation = [1, 0, 2]
        /// CHECK: %[[MATMUL_OUT:.*]] = linalgx.batch_reduce_matmul_vnni ins(%[[TRANSPOSE_OUT]], %[[COLLAPSE_OUT]] :
        %16 = linalgx.batch_reduce_matmul_vnni ins(%transposed, %collapsed : tensor<2x128x16xbf16>, tensor<2x8x32x2xbf16>) outs(%arg6 : tensor<128x32xf32>) -> tensor<128x32xf32>
        // On the last K iteration, down-convert the f32 accumulator to bf16.
        %17 = arith.addi %arg5, %c2 : index
        %18 = arith.cmpi sge, %17, %c64 : index
        /// CHECK: %[[IF_RESULT:.*]] = scf.if
        %19 = scf.if %18 -> (tensor<128x32xbf16>) {
          %20 = linalg.copy ins(%16 : tensor<128x32xf32>) outs(%arg7 : tensor<128x32xbf16>) -> tensor<128x32xbf16>
          scf.yield %20 : tensor<128x32xbf16>
        } else {
          scf.yield %arg7 : tensor<128x32xbf16>
        }
        /// CHECK: scf.yield %[[MATMUL_OUT]], %[[IF_RESULT]] :
        scf.yield %16, %19 : tensor<128x32xf32>, tensor<128x32xbf16>
      }
      /// CHECK: %[[BROADCAST_OUT:.*]] = linalg.broadcast
      /// CHECK: %[[ADD_OUT:.*]] = linalg.add ins(%[[TMP_RESULT]]#1, %[[BROADCAST_OUT]] :
      /// CHECK: %[[MAX_OUT:.*]] = linalg.max ins(%[[ADD_OUT]],
      %13 = affine.apply #map(%arg3)
      scf.forall.in_parallel {
        /// CHECK: tensor.parallel_insert_slice %[[MAX_OUT]]
        /// CHECK: tensor.parallel_insert_slice
        /// CHECK: tensor.parallel_insert_slice
        tensor.parallel_insert_slice %12#1 into %arg4[0, %13] [128, 32] [1, 1] : tensor<128x32xbf16> into tensor<128x512xbf16>
      }
    }
    // Epilogue: bias broadcast + add + ReLU (max with zero); the pass under
    // test is expected to fuse these into the forall (see CHECK lines above).
    %4 = tensor.empty() : tensor<128x512xbf16>
    %broadcasted = linalg.broadcast ins(%arg2 : tensor<512xbf16>) outs(%4 : tensor<128x512xbf16>) dimensions = [0]
    %5 = tensor.empty() : tensor<128x512xbf16>
    %6 = linalg.add ins(%3, %broadcasted : tensor<128x512xbf16>, tensor<128x512xbf16>) outs(%5 : tensor<128x512xbf16>) -> tensor<128x512xbf16>
    %7 = tensor.empty() : tensor<128x512xbf16>
    %8 = linalg.max ins(%6, %cst : tensor<128x512xbf16>, tensor<128x512xbf16>) outs(%7 : tensor<128x512xbf16>) -> tensor<128x512xbf16>
    /// CHECK: return %[[FINAL_RESULT]]#2
    return %8 : tensor<128x512xbf16>
  }
}

// -----

98
175
#map = affine_map <(d0 ) -> (d0 * 128 )>
99
176
module {
100
177
/// CHECK-LABEL: @fuse_multiple_consumer
0 commit comments