Skip to content

[mlir][vector] Generalize folding of ext-contractionOp to other types. #96593

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 25, 2024

Conversation

raikonenfnu
Copy link
Member

Many state of the art models and quantization operations are now directly working on vector.contract on integers.

This commit enables generalizes ext-contraction folding S.T we can emit more performant vector.contracts on codegen pipelines.

Many state of the art models and quantization operations are now
directly working on vector.contract on integers.

This commit enables generalizes ext-contraction folding S.T we
can emit more performant vector.contracts on codegen pipelines.

Signed-off-by: Stanley Winata <[email protected]>
@llvmbot
Copy link
Member

llvmbot commented Jun 25, 2024

@llvm/pr-subscribers-mlir-vector

Author: Stanley Winata (raikonenfnu)

Changes

Many state of the art models and quantization operations are now directly working on vector.contract on integers.

This commit enables generalizes ext-contraction folding S.T we can emit more performant vector.contracts on codegen pipelines.


Full diff: https://github.com/llvm/llvm-project/pull/96593.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+6-3)
  • (modified) mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir (+22)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index ea4a02f2f2e77..6dc0e1c1b4bd8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1542,6 +1542,7 @@ struct CanonicalizeContractMatmulToMMT final
 /// Cores, i.e, `mma.sync.*.f32.f16.f16.f32` and `mma.sync.*.f32.bf16.bf16.f32`.
 /// This pattern folds the arithmetic extensions into the vector contraction and
 /// enables the usage of native mixed precision Tensor Core instructions.
+template <typename ExtOp>
 struct FoldArithExtIntoContractionOp
     : public OpRewritePattern<vector::ContractionOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -1549,8 +1550,8 @@ struct FoldArithExtIntoContractionOp
   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                 PatternRewriter &rewriter) const override {
 
-    auto lhsDefOp = contractOp.getLhs().getDefiningOp<arith::ExtFOp>();
-    auto rhsDefOp = contractOp.getRhs().getDefiningOp<arith::ExtFOp>();
+    auto lhsDefOp = contractOp.getLhs().getDefiningOp<ExtOp>();
+    auto rhsDefOp = contractOp.getRhs().getDefiningOp<ExtOp>();
 
     if (!lhsDefOp || !rhsDefOp) {
       return rewriter.notifyMatchFailure(contractOp,
@@ -1804,7 +1805,9 @@ struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
 
 void mlir::vector::populateFoldArithExtensionPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<FoldArithExtIntoContractionOp>(patterns.getContext());
+  patterns.add<FoldArithExtIntoContractionOp<arith::ExtFOp>,
+               FoldArithExtIntoContractionOp<arith::ExtSIOp>>(
+      patterns.getContext());
 }
 
 void mlir::vector::populateVectorMaskMaterializationPatterns(
diff --git a/mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir b/mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
index 31ae126906f21..6dbde7afbdd33 100644
--- a/mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
+++ b/mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
@@ -48,3 +48,25 @@ func.func @fold_arith_extf_into_contract_scalable(
       %lhs_f32, %rhs_f32, %arg2 : vector<[64]x64xf32>, vector<64x64xf32> into vector<[64]x64xf32>
     return %result : vector<[64]x64xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @fold_arith_extsi_into_contract
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<64x64xi8>, %[[ARG1:.*]]: vector<64x64xi8>, %[[ARG2:.*]]: vector<64x64xi32>)
+//  CHECK-NEXT:   %[[R:.+]] = vector.contract
+//  CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+//  CHECK-SAME:   %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<64x64xi8>, vector<64x64xi8> into vector<64x64xi32>
+//  CHECK-NEXT:   return %[[R]] : vector<64x64xi32>
+func.func @fold_arith_extsi_into_contract(
+  %arg0: vector<64x64xi8>,
+  %arg1: vector<64x64xi8>,
+  %arg2: vector<64x64xi32>) -> vector<64x64xi32> {
+    %lhs_i32 = arith.extsi %arg0 : vector<64x64xi8> to vector<64x64xi32>
+    %rhs_i32 = arith.extsi %arg1 : vector<64x64xi8> to vector<64x64xi32>
+    %result = vector.contract {
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel", "reduction"],
+      kind = #vector.kind<add>}
+      %lhs_i32, %rhs_i32, %arg2 : vector<64x64xi32>, vector<64x64xi32> into vector<64x64xi32>
+    return %result : vector<64x64xi32>
+}

@llvmbot
Copy link
Member

llvmbot commented Jun 25, 2024

@llvm/pr-subscribers-mlir

Author: Stanley Winata (raikonenfnu)

Changes

Many state of the art models and quantization operations are now directly working on vector.contract on integers.

This commit enables generalizes ext-contraction folding S.T we can emit more performant vector.contracts on codegen pipelines.


Full diff: https://github.com/llvm/llvm-project/pull/96593.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+6-3)
  • (modified) mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir (+22)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index ea4a02f2f2e77..6dc0e1c1b4bd8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1542,6 +1542,7 @@ struct CanonicalizeContractMatmulToMMT final
 /// Cores, i.e, `mma.sync.*.f32.f16.f16.f32` and `mma.sync.*.f32.bf16.bf16.f32`.
 /// This pattern folds the arithmetic extensions into the vector contraction and
 /// enables the usage of native mixed precision Tensor Core instructions.
+template <typename ExtOp>
 struct FoldArithExtIntoContractionOp
     : public OpRewritePattern<vector::ContractionOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -1549,8 +1550,8 @@ struct FoldArithExtIntoContractionOp
   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                 PatternRewriter &rewriter) const override {
 
-    auto lhsDefOp = contractOp.getLhs().getDefiningOp<arith::ExtFOp>();
-    auto rhsDefOp = contractOp.getRhs().getDefiningOp<arith::ExtFOp>();
+    auto lhsDefOp = contractOp.getLhs().getDefiningOp<ExtOp>();
+    auto rhsDefOp = contractOp.getRhs().getDefiningOp<ExtOp>();
 
     if (!lhsDefOp || !rhsDefOp) {
       return rewriter.notifyMatchFailure(contractOp,
@@ -1804,7 +1805,9 @@ struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
 
 void mlir::vector::populateFoldArithExtensionPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<FoldArithExtIntoContractionOp>(patterns.getContext());
+  patterns.add<FoldArithExtIntoContractionOp<arith::ExtFOp>,
+               FoldArithExtIntoContractionOp<arith::ExtSIOp>>(
+      patterns.getContext());
 }
 
 void mlir::vector::populateVectorMaskMaterializationPatterns(
diff --git a/mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir b/mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
index 31ae126906f21..6dbde7afbdd33 100644
--- a/mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
+++ b/mlir/test/Dialect/Vector/fold-arith-extf-into-vector-contract.mlir
@@ -48,3 +48,25 @@ func.func @fold_arith_extf_into_contract_scalable(
       %lhs_f32, %rhs_f32, %arg2 : vector<[64]x64xf32>, vector<64x64xf32> into vector<[64]x64xf32>
     return %result : vector<[64]x64xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @fold_arith_extsi_into_contract
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<64x64xi8>, %[[ARG1:.*]]: vector<64x64xi8>, %[[ARG2:.*]]: vector<64x64xi32>)
+//  CHECK-NEXT:   %[[R:.+]] = vector.contract
+//  CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+//  CHECK-SAME:   %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<64x64xi8>, vector<64x64xi8> into vector<64x64xi32>
+//  CHECK-NEXT:   return %[[R]] : vector<64x64xi32>
+func.func @fold_arith_extsi_into_contract(
+  %arg0: vector<64x64xi8>,
+  %arg1: vector<64x64xi8>,
+  %arg2: vector<64x64xi32>) -> vector<64x64xi32> {
+    %lhs_i32 = arith.extsi %arg0 : vector<64x64xi8> to vector<64x64xi32>
+    %rhs_i32 = arith.extsi %arg1 : vector<64x64xi8> to vector<64x64xi32>
+    %result = vector.contract {
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel", "reduction"],
+      kind = #vector.kind<add>}
+      %lhs_i32, %rhs_i32, %arg2 : vector<64x64xi32>, vector<64x64xi32> into vector<64x64xi32>
+    return %result : vector<64x64xi32>
+}

@raikonenfnu raikonenfnu requested a review from antiagainst June 25, 2024 15:26
@raikonenfnu raikonenfnu merged commit ac1e22f into llvm:main Jun 25, 2024
11 checks passed
kuhar pushed a commit to iree-org/llvm-project that referenced this pull request Jun 27, 2024
llvm#96593)

Many state of the art models and quantization operations are now
directly working on vector.contract on integers.

This commit enables generalizes ext-contraction folding S.T we can emit
more performant vector.contracts on codegen pipelines.

Signed-off-by: Stanley Winata <[email protected]>
AlexisPerry pushed a commit to llvm-project-tlp/llvm-project that referenced this pull request Jul 9, 2024
llvm#96593)

Many state of the art models and quantization operations are now
directly working on vector.contract on integers.

This commit enables generalizes ext-contraction folding S.T we can emit
more performant vector.contracts on codegen pipelines.

Signed-off-by: Stanley Winata <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants