[mlir][vector] Generalize folding of ext-contractionOp to other types. #96593
Conversation
Many state-of-the-art models and quantization operations now work directly on integer vector.contract ops. This commit generalizes the ext-contraction folding so that we can emit more performant vector.contract ops in codegen pipelines. Signed-off-by: Stanley Winata <[email protected]>
@llvm/pr-subscribers-mlir-vector
Author: Stanley Winata (raikonenfnu)
Changes: Many state-of-the-art models and quantization operations now work directly on integer vector.contract ops. This commit generalizes the ext-contraction folding so that we can emit more performant vector.contract ops in codegen pipelines.
Full diff: https://github.com/llvm/llvm-project/pull/96593.diff
2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index ea4a02f2f2e77..6dc0e1c1b4bd8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1542,6 +1542,7 @@ struct CanonicalizeContractMatmulToMMT final
/// Cores, i.e, `mma.sync.*.f32.f16.f16.f32` and `mma.sync.*.f32.bf16.bf16.f32`.
/// This pattern folds the arithmetic extensions into the vector contraction and
/// enables the usage of native mixed precision Tensor Core instructions.
+template <typename ExtOp>
struct FoldArithExtIntoContractionOp
: public OpRewritePattern<vector::ContractionOp> {
using OpRewritePattern::OpRewritePattern;
@@ -1549,8 +1550,8 @@ struct FoldArithExtIntoContractionOp
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
PatternRewriter &rewriter) const override {
- auto lhsDefOp = contractOp.getLhs().getDefiningOp<arith::ExtFOp>();
- auto rhsDefOp = contractOp.getRhs().getDefiningOp<arith::ExtFOp>();
+ auto lhsDefOp = contractOp.getLhs().getDefiningOp<ExtOp>();
+ auto rhsDefOp = contractOp.getRhs().getDefiningOp<ExtOp>();
if (!lhsDefOp || !rhsDefOp) {
return rewriter.notifyMatchFailure(contractOp,
@@ -1804,7 +1805,9 @@ struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
void mlir::vector::populateFoldArithExtensionPatterns(
RewritePatternSet &patterns) {
- patterns.add<FoldArithExtIntoContractionOp>(patterns.getContext());
+ patterns.add<FoldArithExtIntoContractionOp<arith::ExtFOp>,
+ FoldArithExtIntoContractionOp<arith::ExtSIOp>>(
+ patterns.getContext());
}
void mlir::vector::populateVectorMaskMaterializationPatterns(
diff --git a/mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir b/mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
index 31ae126906f21..6dbde7afbdd33 100644
--- a/mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
+++ b/mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
@@ -48,3 +48,25 @@ func.func @fold_arith_extf_into_contract_scalable(
%lhs_f32, %rhs_f32, %arg2 : vector<[64]x64xf32>, vector<64x64xf32> into vector<[64]x64xf32>
return %result : vector<[64]x64xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @fold_arith_extsi_into_contract
+// CHECK-SAME: (%[[ARG0:.*]]: vector<64x64xi8>, %[[ARG1:.*]]: vector<64x64xi8>, %[[ARG2:.*]]: vector<64x64xi32>)
+// CHECK-NEXT: %[[R:.+]] = vector.contract
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<64x64xi8>, vector<64x64xi8> into vector<64x64xi32>
+// CHECK-NEXT: return %[[R]] : vector<64x64xi32>
+func.func @fold_arith_extsi_into_contract(
+ %arg0: vector<64x64xi8>,
+ %arg1: vector<64x64xi8>,
+ %arg2: vector<64x64xi32>) -> vector<64x64xi32> {
+ %lhs_i32 = arith.extsi %arg0 : vector<64x64xi8> to vector<64x64xi32>
+ %rhs_i32 = arith.extsi %arg1 : vector<64x64xi8> to vector<64x64xi32>
+ %result = vector.contract {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>}
+ %lhs_i32, %rhs_i32, %arg2 : vector<64x64xi32>, vector<64x64xi32> into vector<64x64xi32>
+ return %result : vector<64x64xi32>
+}
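
These folds are opt-in rather than canonicalizations: a pipeline picks them up by collecting populateFoldArithExtensionPatterns into a pattern set and running a rewrite driver. Below is a minimal sketch of how a downstream pass body might apply them; the wrapper function and the func.func anchor are hypothetical, while the populate call and the greedy driver are the existing MLIR APIs touched by this patch.

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical driver: fold arith.extf/arith.extsi producers into
// vector.contract ops nested under `funcOp`.
static void foldExtIntoContract(mlir::func::FuncOp funcOp) {
  mlir::RewritePatternSet patterns(funcOp.getContext());
  // With this PR, the call below seeds both the ExtFOp and the ExtSIOp
  // instantiations of FoldArithExtIntoContractionOp.
  mlir::vector::populateFoldArithExtensionPatterns(patterns);
  if (mlir::failed(
          mlir::applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
    funcOp.emitWarning("ext-into-contract folding did not converge");
}

Any isolated-from-above op works as the rewrite root; func.func is used here only as a convenient anchor.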