Skip to content

[mlir][Vector] Add support for sub-byte transpose emulation #80110

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 1, 2024

Conversation

dcaballe
Copy link
Contributor

This PR adds patterns to convert a sub-byte vector transpose into a sequence of instructions that perform the transpose on i8 vector elements. Whereas this rewrite may not lead to the absolute peak performance, it should ensure correctness when dealing with sub-byte transposes.

This PR adds patterns to convert a sub-byte vector transpose into
a sequence of instructions that perform the transpose on i8 vector
elements. Whereas this rewrite may not lead to the absolute peak
performance, it should ensure correctness when dealing with sub-byte
transposes.
@llvmbot
Copy link
Member

llvmbot commented Jan 31, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Diego Caballero (dcaballe)

Changes

This PR adds patterns to convert a sub-byte vector transpose into a sequence of instructions that perform the transpose on i8 vector elements. Whereas this rewrite may not lead to the absolute peak performance, it should ensure correctness when dealing with sub-byte transposes.


Full diff: https://github.com/llvm/llvm-project/pull/80110.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td (+2-2)
  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (+4)
  • (modified) mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp (+1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+51)
  • (modified) mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir (+21)
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 3ac6f28dcb938..ce88360aa52e9 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -151,7 +151,7 @@ def ApplyLowerMaskedTransfersPatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.lower_masked_transfers",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
-    Apply opt-in patterns that lower vector.mask operations surrounding 
+    Apply opt-in patterns that lower vector.mask operations surrounding
     side-effecting ops:
       - MaskedTransferReadOpPattern
       - MaskedTransferWriteOpPattern
@@ -376,7 +376,7 @@ def ApplyVectorReductionToContractPatternsOp : Op<Transform_Dialect,
       - ReorderCastOpsOnBroadcast
       - ReorderElementwiseOpsOnTranspose
 
-    These patterns have the effect of rewriting a vector.multi_reduce into a 
+    These patterns have the effect of rewriting a vector.multi_reduce into a
     vector.contract.
   }];
 
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 49b74c0c466d2..f5941d32e683f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -371,6 +371,10 @@ FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
 void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
                                              PatternBenefit benefit = 1);
 
+/// Appends patterns for emulating a sub-byte vector transpose.
+void populateVectorTransposeNarrowTypeRewritePatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit = 1);
+
 } // namespace vector
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 37127ea70f1e5..19922c4295fe0 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -162,6 +162,7 @@ void transform::ApplyLowerTransposePatternsOp::populatePatterns(
 void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   populateVectorNarrowTypeRewritePatterns(patterns);
+  populateVectorTransposeNarrowTypeRewritePatterns(patterns);
 }
 
 void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns(
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 0110a8df89aee..193c9a6182b49 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1052,6 +1052,52 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
   }
 };
 
+/// Rewrite a sub-byte vector transpose into a sequence of instructions that
+/// perform the transpose on wider (byte) element types.
+/// For example:
+///   %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
+///
+///   is rewritten as:
+///
+///   %0 = arith.extsi %arg0 : vector<8x16xi4> to vector<8x16xi8>
+///   %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8>
+///   %2 = arith.trunci %1 : vector<16x8xi8> to vector<16x8xi4>
+///
+struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+
+  RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit)
+      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
+
+  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    // Precondition: sub-byte integer transpose.
+    constexpr unsigned minNativeBitwidth = 8;
+    VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
+    if (srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth)
+      return rewriter.notifyMatchFailure(transposeOp,
+                                         "not a sub-byte transpose");
+
+    // Perform the rewrite.
+    Location loc = transposeOp.getLoc();
+    // Signed/unsigned interpretation shouldn't matter here as we are just
+    // transposing the elements and truncating them back to the original size.
+    // TODO: Use unsigned extension (more efficient) when emulation or backend
+    // support is available.
+    auto srcNativeVecType =
+        srcSubByteVecType.cloneWith(std::nullopt, rewriter.getI8Type());
+    Value extOp = rewriter.create<arith::ExtSIOp>(loc, srcNativeVecType,
+                                                  transposeOp.getVector());
+    Value newTranspose = rewriter.create<vector::TransposeOp>(
+        loc, extOp, transposeOp.getPermutation());
+    VectorType dstSubByteVecType = transposeOp.getResultVectorType();
+    rewriter.replaceOpWithNewOp<arith::TruncIOp>(transposeOp, dstSubByteVecType,
+                                                 newTranspose);
+    return success();
+  }
+};
+
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -1080,3 +1126,8 @@ void vector::populateVectorNarrowTypeRewritePatterns(
                RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>>(
       patterns.getContext(), benefit.getBenefit() + 1);
 }
+
+void vector::populateVectorTransposeNarrowTypeRewritePatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index c4fbb4c219b91..02063a81664b8 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -226,6 +226,26 @@ func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> {
   return %0 : vector<8xf32>
 }
 
+// CHECK-LABEL: func.func @i4_transpose(
+//  CHECK-SAME: %[[A:[0-9a-z]*]]
+func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> {
+  // CHECK: %[[EXT:.*]] = arith.extsi %[[A]] : vector<8x16xi4> to vector<8x16xi8>
+  // CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
+  // CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi4>
+  %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
+  return %0 : vector<16x8xi4>
+}
+
+// CHECK-LABEL: func.func @i7_transpose(
+//  CHECK-SAME: %[[A:[0-9a-z]*]]
+func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> {
+  // CHECK: %[[EXT:.*]] = arith.extsi %[[A]] : vector<8x16xi7> to vector<8x16xi8>
+  // CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
+  // CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi7>
+  %0 = vector.transpose %a, [1, 0] : vector<8x16xi7> to vector<16x8xi7>
+  return %0 : vector<16x8xi7>
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op
@@ -237,3 +257,4 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+

Copy link

github-actions bot commented Jan 31, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

// TODO: Use unsigned extension (more efficient) when emulation or backend
// support is available.
auto srcNativeVecType =
srcSubByteVecType.cloneWith(std::nullopt, rewriter.getI8Type());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think rewriter.getIntegerType(minNativeBitwidth) is better for consistency.

@dcaballe dcaballe requested review from pzread, bviyer and Max191 January 31, 2024 19:46
@dcaballe dcaballe merged commit 8ba018d into llvm:main Feb 1, 2024
Copy link
Contributor

@nicolasvasilache nicolasvasilache left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, thanks!

carlosgalvezp pushed a commit to carlosgalvezp/llvm-project that referenced this pull request Feb 1, 2024
This PR adds patterns to convert a sub-byte vector transpose into a
sequence of instructions that perform the transpose on i8 vector
elements. Whereas this rewrite may not lead to the absolute peak
performance, it should ensure correctness when dealing with sub-byte
transposes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants