@@ -151,43 +151,56 @@ func.func @extract_contract3(%arg0: vector<3xf32>,
151
151
iterator_types = [" parallel" , " parallel" , " reduction" ]
152
152
}
153
153
154
- // CHECK-LABEL: func @extract_contract4
155
- // CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>,
156
- // CHECK-SAME: %[[B:.*1]]: vector<2x2xf32>,
157
- // CHECK-SAME: %[[C:.*2]]: vector<2x2xf32>
158
- // CHECK: %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
159
- // CHECK: %[[Bt:.*]] = vector.transpose %arg1, [1, 0] : vector<2x2xf32> to vector<2x2xf32>
160
- // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
161
- // CHECK: %[[T2:.*]] = vector.extract %[[Bt]][0] : vector<2xf32> from vector<2x2xf32>
162
- // CHECK: %[[T9:.*]] = arith.mulf %[[T0]], %[[T2]] : vector<2xf32>
163
- // CHECK: %[[T10:.*]] = vector.reduction <add>, %[[T9]] : vector<2xf32> into f32
164
- // CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[R]] [0, 0] : f32 into vector<2x2xf32>
154
+ // CHECK-LABEL: func @contract_to_dot_matmat
155
+ // CHECK-SAME: %[[LHS:.*0]]: vector<2x2xf32>,
156
+ // CHECK-SAME: %[[RHS:.*1]]: vector<2x2xf32>,
157
+ // CHECK-SAME: %[[OUT:.*2]]: vector<2x2xf32>
165
158
//
166
- // CHECK: %[[T12:.*]] = vector.extract %[[Bt]][1] : vector<2xf32> from vector<2x2xf32>
167
- // CHECK: %[[T19:.*]] = arith.mulf %[[T0]], %[[T12]] : vector<2xf32>
168
- // CHECK: %[[T20:.*]] = vector.reduction <add>, %[[T19]] : vector<2xf32> into f32
169
- // CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [0, 1] : f32 into vector<2x2xf32>
159
+ // The `vector.contract` to dot lowering will 'unroll' a matrix-matrix
160
+ // multiplication into individual dot products betweem rows of the LHS with columns
161
+ // of the RHS. In the following test we expect 4 extract-dotproduct-insert sequences of
162
+ // ops that correspond to the 4 dot products resulting from unrolling a matmul between
163
+ // two matrices of size (2, 2).
170
164
//
171
- // CHECK: %[[T23:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
172
- // CHECK: %[[T24:.*]] = vector.extract %[[Bt]][0] : vector<2xf32> from vector<2x2xf32>
173
- // CHECK: %[[T32:.*]] = arith.mulf %[[T23]], %[[T24]] : vector<2xf32>
174
- // CHECK: %[[T33:.*]] = vector.reduction <add>, %[[T32]] : vector<2xf32> into f32
175
- // CHECK: %[[T34:.*]] = vector.insert %[[T33]], %[[T21]] [1, 0] : f32 into vector<2x2xf32>
165
+ // CHECK: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
176
166
//
177
- // CHECK: %[[T40:.*]] = vector.extract %[[Bt]][1] : vector<2xf32> from vector<2x2xf32>
178
- // CHECK: %[[T41:.*]] = arith.mulf %[[T23]], %[[T40]] : vector<2xf32>
179
- // CHECK: %[[T42:.*]] = vector.reduction <add>, %[[T41]] : vector<2xf32> into f32
180
- // CHECK: %[[T43:.*]] = vector.insert %[[T42]], %[[T34]] [1, 1] : f32 into vector<2x2xf32>
167
+ // First, The RHS will be transposed to make it easier to extract individual columns
168
+ // using vector.extract.
181
169
//
182
- // CHECK: %[[T52:.*]] = arith.addf %[[T43]], %[[C]] : vector<2x2xf32>
183
- // CHECK: return %[[T52]] : vector<2x2xf32>
170
+ // CHECK: %[[RHS_T:.*]] = vector.transpose %[[RHS]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
171
+ //
172
+ // Next, we expect 4 sequences of extracting rows of the RHS, LHS, performing a dot
173
+ // product and then inserting it into the result.
174
+ //
175
+ // CHECK: %[[LHS0:.*]] = vector.extract %[[LHS]][0] : vector<2xf32> from vector<2x2xf32>
176
+ // CHECK: %[[RHS_T0:.*]] = vector.extract %[[RHS_T]][0] : vector<2xf32> from vector<2x2xf32>
177
+ // CHECK: %[[PROD0:.*]] = arith.mulf %[[LHS0]], %[[RHS_T0]] : vector<2xf32>
178
+ // CHECK: %[[SUM0:.*]] = vector.reduction <add>, %[[PROD0]] : vector<2xf32> into f32
179
+ // CHECK: %[[RES0:.*]] = vector.insert %[[SUM0]], %[[INIT]] [0, 0] : f32 into vector<2x2xf32>
180
+ //
181
+ // CHECK: %[[RHS_T1:.*]] = vector.extract %[[RHS_T]][1] : vector<2xf32> from vector<2x2xf32>
182
+ // CHECK: %[[PROD1:.*]] = arith.mulf %[[LHS0]], %[[RHS_T1]] : vector<2xf32>
183
+ // CHECK: %[[SUM1:.*]] = vector.reduction <add>, %[[PROD1]] : vector<2xf32> into f32
184
+ // CHECK: %[[RES1:.*]] = vector.insert %[[SUM1]], %[[RES0]] [0, 1] : f32 into vector<2x2xf32>
185
+ //
186
+ // CHECK: %[[LHS1:.*]] = vector.extract %[[LHS]][1] : vector<2xf32> from vector<2x2xf32>
187
+ // CHECK: %[[PROD2:.*]] = arith.mulf %[[LHS1]], %[[RHS_T0]] : vector<2xf32>
188
+ // CHECK: %[[SUM2:.*]] = vector.reduction <add>, %[[PROD2]] : vector<2xf32> into f32
189
+ // CHECK: %[[RES2:.*]] = vector.insert %[[SUM2]], %[[RES1]] [1, 0] : f32 into vector<2x2xf32>
190
+ //
191
+ // CHECK: %[[PROD3:.*]] = arith.mulf %[[LHS1]], %[[RHS_T1]] : vector<2xf32>
192
+ // CHECK: %[[SUM3:.*]] = vector.reduction <add>, %[[PROD3]] : vector<2xf32> into f32
193
+ // CHECK: %[[RES3:.*]] = vector.insert %[[SUM3]], %[[RES2]] [1, 1] : f32 into vector<2x2xf32>
194
+ //
195
+ // CHECK: %[[RES:.*]] = arith.addf %[[RES3]], %[[OUT]] : vector<2x2xf32>
196
+ // CHECK: return %[[RES]] : vector<2x2xf32>
184
197
185
- func.func @extract_contract4 ( %arg0 : vector <2 x2 xf32 >,
186
- %arg1 : vector <2 x2 xf32 >,
187
- %arg2 : vector <2 x2 xf32 >) -> vector <2 x2 xf32 > {
188
- %0 = vector.contract #matmat_trait %arg0 , %arg1 , %arg2
198
+ func.func @contract_to_dot_matmat ( %lhs : vector <2 x2 xf32 >,
199
+ %rhs : vector <2 x2 xf32 >,
200
+ %init : vector <2 x2 xf32 >) -> vector <2 x2 xf32 > {
201
+ %res = vector.contract #matmat_trait %lhs , %rhs , %init
189
202
: vector <2 x2 xf32 >, vector <2 x2 xf32 > into vector <2 x2 xf32 >
190
- return %0 : vector <2 x2 xf32 >
203
+ return %res : vector <2 x2 xf32 >
191
204
}
192
205
193
206
0 commit comments