Skip to content

Commit fc127ff

Browse files
authored
[mlir] Extract RHS rows once when lowering vector.contract to dot (llvm#130130)
The `vector.contract` op on two matrices A and B will be lowered to individual dot products of each row of A with each column of B. The existing lowering extracts each column of B again for every row of A, which leads to multiple values in the IR representing the same columns of B. This PR changes `ContractionOpToDotLowering` so that each column of B is extracted only once, and the SSA values representing the extracted columns are re-used in the IR for later dot products. I have updated the existing vector-contract-to-dot-transforms test.
1 parent 1a626e6 commit fc127ff

File tree

2 files changed

+61
-39
lines changed

2 files changed

+61
-39
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -757,19 +757,28 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
757757
Value res = rewriter.create<arith::ConstantOp>(loc, dstType,
758758
rewriter.getZeroAttr(dstType));
759759
bool isInt = isa<IntegerType>(dstType.getElementType());
760+
llvm::SmallVector<Value> extractedCols;
761+
extractedCols.reserve(dstColumns);
760762
for (unsigned r = 0; r < dstRows; ++r) {
761-
Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
763+
Value rowLhs = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, r);
762764
for (unsigned c = 0; c < dstColumns; ++c) {
763-
Value b = rank == 1
764-
? rhs
765-
: rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
766-
Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
767-
Value reduced = rewriter.create<vector::ReductionOp>(
768-
op.getLoc(), vector::CombiningKind::ADD, m);
765+
// Extract each respective row and column of the LHS and RHS once to
766+
// avoid having duplicate SSA values pointing to the same rows/columns.
767+
if (r == 0) {
768+
Value colRhs =
769+
rank == 1 ? rhs
770+
: rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
771+
extractedCols.push_back(colRhs);
772+
}
773+
Value extractedColRhs = extractedCols[c];
774+
Value product =
775+
createMul(op.getLoc(), rowLhs, extractedColRhs, isInt, rewriter);
776+
Value sum = rewriter.create<vector::ReductionOp>(
777+
op.getLoc(), vector::CombiningKind::ADD, product);
769778

770779
SmallVector<int64_t, 2> pos = rank == 1 ? SmallVector<int64_t, 2>{r}
771780
: SmallVector<int64_t, 2>{r, c};
772-
res = rewriter.create<vector::InsertOp>(op.getLoc(), reduced, res, pos);
781+
res = rewriter.create<vector::InsertOp>(op.getLoc(), sum, res, pos);
773782
}
774783
}
775784
if (auto acc = op.getAcc())

mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir

Lines changed: 44 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -151,43 +151,56 @@ func.func @extract_contract3(%arg0: vector<3xf32>,
151151
iterator_types = ["parallel", "parallel", "reduction"]
152152
}
153153

154-
// CHECK-LABEL: func @extract_contract4
155-
// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>,
156-
// CHECK-SAME: %[[B:.*1]]: vector<2x2xf32>,
157-
// CHECK-SAME: %[[C:.*2]]: vector<2x2xf32>
158-
// CHECK: %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
159-
// CHECK: %[[Bt:.*]] = vector.transpose %arg1, [1, 0] : vector<2x2xf32> to vector<2x2xf32>
160-
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
161-
// CHECK: %[[T2:.*]] = vector.extract %[[Bt]][0] : vector<2xf32> from vector<2x2xf32>
162-
// CHECK: %[[T9:.*]] = arith.mulf %[[T0]], %[[T2]] : vector<2xf32>
163-
// CHECK: %[[T10:.*]] = vector.reduction <add>, %[[T9]] : vector<2xf32> into f32
164-
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[R]] [0, 0] : f32 into vector<2x2xf32>
154+
// CHECK-LABEL: func @contract_to_dot_matmat
155+
// CHECK-SAME: %[[LHS:.*0]]: vector<2x2xf32>,
156+
// CHECK-SAME: %[[RHS:.*1]]: vector<2x2xf32>,
157+
// CHECK-SAME: %[[OUT:.*2]]: vector<2x2xf32>
165158
//
166-
// CHECK: %[[T12:.*]] = vector.extract %[[Bt]][1] : vector<2xf32> from vector<2x2xf32>
167-
// CHECK: %[[T19:.*]] = arith.mulf %[[T0]], %[[T12]] : vector<2xf32>
168-
// CHECK: %[[T20:.*]] = vector.reduction <add>, %[[T19]] : vector<2xf32> into f32
169-
// CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [0, 1] : f32 into vector<2x2xf32>
159+
// The `vector.contract` to dot lowering will 'unroll' a matrix-matrix
160+
// multiplication into individual dot products between rows of the LHS and columns
161+
// of the RHS. In the following test we expect 4 extract-dotproduct-insert sequences of
162+
// ops that correspond to the 4 dot products resulting from unrolling a matmul between
163+
// two matrices of size (2, 2).
170164
//
171-
// CHECK: %[[T23:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
172-
// CHECK: %[[T24:.*]] = vector.extract %[[Bt]][0] : vector<2xf32> from vector<2x2xf32>
173-
// CHECK: %[[T32:.*]] = arith.mulf %[[T23]], %[[T24]] : vector<2xf32>
174-
// CHECK: %[[T33:.*]] = vector.reduction <add>, %[[T32]] : vector<2xf32> into f32
175-
// CHECK: %[[T34:.*]] = vector.insert %[[T33]], %[[T21]] [1, 0] : f32 into vector<2x2xf32>
165+
// CHECK: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
176166
//
177-
// CHECK: %[[T40:.*]] = vector.extract %[[Bt]][1] : vector<2xf32> from vector<2x2xf32>
178-
// CHECK: %[[T41:.*]] = arith.mulf %[[T23]], %[[T40]] : vector<2xf32>
179-
// CHECK: %[[T42:.*]] = vector.reduction <add>, %[[T41]] : vector<2xf32> into f32
180-
// CHECK: %[[T43:.*]] = vector.insert %[[T42]], %[[T34]] [1, 1] : f32 into vector<2x2xf32>
167+
// First, the RHS will be transposed to make it easier to extract individual columns
168+
// using vector.extract.
181169
//
182-
// CHECK: %[[T52:.*]] = arith.addf %[[T43]], %[[C]] : vector<2x2xf32>
183-
// CHECK: return %[[T52]] : vector<2x2xf32>
170+
// CHECK: %[[RHS_T:.*]] = vector.transpose %[[RHS]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
171+
//
172+
// Next, we expect 4 sequences of extracting rows of the LHS and the transposed RHS, performing a dot
173+
// product and then inserting it into the result.
174+
//
175+
// CHECK: %[[LHS0:.*]] = vector.extract %[[LHS]][0] : vector<2xf32> from vector<2x2xf32>
176+
// CHECK: %[[RHS_T0:.*]] = vector.extract %[[RHS_T]][0] : vector<2xf32> from vector<2x2xf32>
177+
// CHECK: %[[PROD0:.*]] = arith.mulf %[[LHS0]], %[[RHS_T0]] : vector<2xf32>
178+
// CHECK: %[[SUM0:.*]] = vector.reduction <add>, %[[PROD0]] : vector<2xf32> into f32
179+
// CHECK: %[[RES0:.*]] = vector.insert %[[SUM0]], %[[INIT]] [0, 0] : f32 into vector<2x2xf32>
180+
//
181+
// CHECK: %[[RHS_T1:.*]] = vector.extract %[[RHS_T]][1] : vector<2xf32> from vector<2x2xf32>
182+
// CHECK: %[[PROD1:.*]] = arith.mulf %[[LHS0]], %[[RHS_T1]] : vector<2xf32>
183+
// CHECK: %[[SUM1:.*]] = vector.reduction <add>, %[[PROD1]] : vector<2xf32> into f32
184+
// CHECK: %[[RES1:.*]] = vector.insert %[[SUM1]], %[[RES0]] [0, 1] : f32 into vector<2x2xf32>
185+
//
186+
// CHECK: %[[LHS1:.*]] = vector.extract %[[LHS]][1] : vector<2xf32> from vector<2x2xf32>
187+
// CHECK: %[[PROD2:.*]] = arith.mulf %[[LHS1]], %[[RHS_T0]] : vector<2xf32>
188+
// CHECK: %[[SUM2:.*]] = vector.reduction <add>, %[[PROD2]] : vector<2xf32> into f32
189+
// CHECK: %[[RES2:.*]] = vector.insert %[[SUM2]], %[[RES1]] [1, 0] : f32 into vector<2x2xf32>
190+
//
191+
// CHECK: %[[PROD3:.*]] = arith.mulf %[[LHS1]], %[[RHS_T1]] : vector<2xf32>
192+
// CHECK: %[[SUM3:.*]] = vector.reduction <add>, %[[PROD3]] : vector<2xf32> into f32
193+
// CHECK: %[[RES3:.*]] = vector.insert %[[SUM3]], %[[RES2]] [1, 1] : f32 into vector<2x2xf32>
194+
//
195+
// CHECK: %[[RES:.*]] = arith.addf %[[RES3]], %[[OUT]] : vector<2x2xf32>
196+
// CHECK: return %[[RES]] : vector<2x2xf32>
184197

185-
func.func @extract_contract4(%arg0: vector<2x2xf32>,
186-
%arg1: vector<2x2xf32>,
187-
%arg2: vector<2x2xf32>) -> vector<2x2xf32> {
188-
%0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
198+
func.func @contract_to_dot_matmat(%lhs: vector<2x2xf32>,
199+
%rhs: vector<2x2xf32>,
200+
%init: vector<2x2xf32>) -> vector<2x2xf32> {
201+
%res = vector.contract #matmat_trait %lhs, %rhs, %init
189202
: vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
190-
return %0 : vector<2x2xf32>
203+
return %res : vector<2x2xf32>
191204
}
192205

193206

0 commit comments

Comments
 (0)