Skip to content

Commit 5672a7c

Browse files
raikonenfnu authored and AlexisPerry committed
[mlir][vector] Generalize folding of ext-contractionOp to other types. (llvm#96593)
Many state-of-the-art models and quantization operations now work directly with vector.contract on integers. This commit generalizes ext-contraction folding so that we can emit more performant vector.contracts in codegen pipelines. Signed-off-by: Stanley Winata <[email protected]>
1 parent 21c84d9 commit 5672a7c

File tree

2 files changed

+28
-3
lines changed

2 files changed

+28
-3
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1552,15 +1552,16 @@ struct CanonicalizeContractMatmulToMMT final
15521552
/// Cores, i.e, `mma.sync.*.f32.f16.f16.f32` and `mma.sync.*.f32.bf16.bf16.f32`.
15531553
/// This pattern folds the arithmetic extensions into the vector contraction and
15541554
/// enables the usage of native mixed precision Tensor Core instructions.
1555+
template <typename ExtOp>
15551556
struct FoldArithExtIntoContractionOp
15561557
: public OpRewritePattern<vector::ContractionOp> {
15571558
using OpRewritePattern::OpRewritePattern;
15581559

15591560
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
15601561
PatternRewriter &rewriter) const override {
15611562

1562-
auto lhsDefOp = contractOp.getLhs().getDefiningOp<arith::ExtFOp>();
1563-
auto rhsDefOp = contractOp.getRhs().getDefiningOp<arith::ExtFOp>();
1563+
auto lhsDefOp = contractOp.getLhs().getDefiningOp<ExtOp>();
1564+
auto rhsDefOp = contractOp.getRhs().getDefiningOp<ExtOp>();
15641565

15651566
if (!lhsDefOp || !rhsDefOp) {
15661567
return rewriter.notifyMatchFailure(contractOp,
@@ -1895,7 +1896,9 @@ struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
18951896

18961897
void mlir::vector::populateFoldArithExtensionPatterns(
18971898
RewritePatternSet &patterns) {
1898-
patterns.add<FoldArithExtIntoContractionOp>(patterns.getContext());
1899+
patterns.add<FoldArithExtIntoContractionOp<arith::ExtFOp>,
1900+
FoldArithExtIntoContractionOp<arith::ExtSIOp>>(
1901+
patterns.getContext());
18991902
}
19001903

19011904
void mlir::vector::populateVectorMaskMaterializationPatterns(

mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,25 @@ func.func @fold_arith_extf_into_contract_scalable(
4848
%lhs_f32, %rhs_f32, %arg2 : vector<[64]x64xf32>, vector<64x64xf32> into vector<[64]x64xf32>
4949
return %result : vector<[64]x64xf32>
5050
}
51+
52+
// -----
53+
54+
// CHECK-LABEL: func.func @fold_arith_extsi_into_contract
55+
// CHECK-SAME: (%[[ARG0:.*]]: vector<64x64xi8>, %[[ARG1:.*]]: vector<64x64xi8>, %[[ARG2:.*]]: vector<64x64xi32>)
56+
// CHECK-NEXT: %[[R:.+]] = vector.contract
57+
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
58+
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<64x64xi8>, vector<64x64xi8> into vector<64x64xi32>
59+
// CHECK-NEXT: return %[[R]] : vector<64x64xi32>
60+
func.func @fold_arith_extsi_into_contract(
61+
%arg0: vector<64x64xi8>,
62+
%arg1: vector<64x64xi8>,
63+
%arg2: vector<64x64xi32>) -> vector<64x64xi32> {
64+
%lhs_i32 = arith.extsi %arg0 : vector<64x64xi8> to vector<64x64xi32>
65+
%rhs_i32 = arith.extsi %arg1 : vector<64x64xi8> to vector<64x64xi32>
66+
%result = vector.contract {
67+
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
68+
iterator_types = ["parallel", "parallel", "reduction"],
69+
kind = #vector.kind<add>}
70+
%lhs_i32, %rhs_i32, %arg2 : vector<64x64xi32>, vector<64x64xi32> into vector<64x64xi32>
71+
return %result : vector<64x64xi32>
72+
}

0 commit comments

Comments
 (0)