-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][Vector] Add support for sub-byte transpose emulation #80110
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This PR adds patterns to convert a sub-byte vector transpose into a sequence of instructions that perform the transpose on i8 vector elements. Whereas this rewrite may not lead to the absolute peak performance, it should ensure correctness when dealing with sub-byte transposes.
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Diego Caballero (dcaballe) ChangesThis PR adds patterns to convert a sub-byte vector transpose into a sequence of instructions that perform the transpose on i8 vector elements. Whereas this rewrite may not lead to the absolute peak performance, it should ensure correctness when dealing with sub-byte transposes. Full diff: https://github.com/llvm/llvm-project/pull/80110.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 3ac6f28dcb938..ce88360aa52e9 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -151,7 +151,7 @@ def ApplyLowerMaskedTransfersPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.lower_masked_transfers",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
- Apply opt-in patterns that lower vector.mask operations surrounding
+ Apply opt-in patterns that lower vector.mask operations surrounding
side-effecting ops:
- MaskedTransferReadOpPattern
- MaskedTransferWriteOpPattern
@@ -376,7 +376,7 @@ def ApplyVectorReductionToContractPatternsOp : Op<Transform_Dialect,
- ReorderCastOpsOnBroadcast
- ReorderElementwiseOpsOnTranspose
- These patterns have the effect of rewriting a vector.multi_reduce into a
+ These patterns have the effect of rewriting a vector.multi_reduce into a
vector.contract.
}];
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 49b74c0c466d2..f5941d32e683f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -371,6 +371,10 @@ FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);
+/// Appends patterns for emulating a sub-byte vector transpose.
+void populateVectorTransposeNarrowTypeRewritePatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit = 1);
+
} // namespace vector
} // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 37127ea70f1e5..19922c4295fe0 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -162,6 +162,7 @@ void transform::ApplyLowerTransposePatternsOp::populatePatterns(
void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
RewritePatternSet &patterns) {
populateVectorNarrowTypeRewritePatterns(patterns);
+ populateVectorTransposeNarrowTypeRewritePatterns(patterns);
}
void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns(
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 0110a8df89aee..193c9a6182b49 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1052,6 +1052,52 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
}
};
+/// Rewrite a sub-byte vector transpose into a sequence of instructions that
+/// perform the transpose on wider (byte) element types.
+/// For example:
+/// %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
+///
+/// is rewritten as:
+///
+/// %0 = arith.extsi %arg0 : vector<8x16xi4> to vector<8x16xi8>
+/// %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8>
+/// %2 = arith.trunci %1 : vector<16x8xi8> to vector<16x8xi4>
+///
+struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
+ using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+
+ RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit)
+ : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
+
+ LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+ PatternRewriter &rewriter) const override {
+ // Precondition: sub-byte integer transpose.
+ constexpr unsigned minNativeBitwidth = 8;
+ VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
+ if (srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth)
+ return rewriter.notifyMatchFailure(transposeOp,
+ "not a sub-byte transpose");
+
+ // Perform the rewrite.
+ Location loc = transposeOp.getLoc();
+ // Signed/unsigned interpretation shouldn't matter here as we are just
+ // transposing the elements and truncating them back to the original size.
+ // TODO: Use unsigned extension (more efficient) when emulation or backend
+ // support is available.
+ auto srcNativeVecType =
+ srcSubByteVecType.cloneWith(std::nullopt, rewriter.getI8Type());
+ Value extOp = rewriter.create<arith::ExtSIOp>(loc, srcNativeVecType,
+ transposeOp.getVector());
+ Value newTranspose = rewriter.create<vector::TransposeOp>(
+ loc, extOp, transposeOp.getPermutation());
+ VectorType dstSubByteVecType = transposeOp.getResultVectorType();
+ rewriter.replaceOpWithNewOp<arith::TruncIOp>(transposeOp, dstSubByteVecType,
+ newTranspose);
+ return success();
+ }
+};
+
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -1080,3 +1126,8 @@ void vector::populateVectorNarrowTypeRewritePatterns(
RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>>(
patterns.getContext(), benefit.getBenefit() + 1);
}
+
+void vector::populateVectorTransposeNarrowTypeRewritePatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index c4fbb4c219b91..02063a81664b8 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -226,6 +226,26 @@ func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> {
return %0 : vector<8xf32>
}
+// CHECK-LABEL: func.func @i4_transpose(
+// CHECK-SAME: %[[A:[0-9a-z]*]]
+func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> {
+ // CHECK: %[[EXT:.*]] = arith.extsi %[[A]] : vector<8x16xi4> to vector<8x16xi8>
+ // CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
+ // CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi4>
+ %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
+ return %0 : vector<16x8xi4>
+}
+
+// CHECK-LABEL: func.func @i7_transpose(
+// CHECK-SAME: %[[A:[0-9a-z]*]]
+func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> {
+ // CHECK: %[[EXT:.*]] = arith.extsi %[[A]] : vector<8x16xi7> to vector<8x16xi8>
+ // CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
+ // CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi7>
+ %0 = vector.transpose %a, [1, 0] : vector<8x16xi7> to vector<16x8xi7>
+ return %0 : vector<16x8xi7>
+}
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
@@ -237,3 +257,4 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
// TODO: Use unsigned extension (more efficient) when emulation or backend | ||
// support is available. | ||
auto srcNativeVecType = | ||
srcSubByteVecType.cloneWith(std::nullopt, rewriter.getI8Type()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think rewriter.getIntegerType(minNativeBitwidth)
is better for consistency.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, thanks!
This PR adds patterns to convert a sub-byte vector transpose into a sequence of instructions that perform the transpose on i8 vector elements. Whereas this rewrite may not lead to the absolute peak performance, it should ensure correctness when dealing with sub-byte transposes.
This PR adds patterns to convert a sub-byte vector transpose into a sequence of instructions that perform the transpose on i8 vector elements. Whereas this rewrite may not lead to the absolute peak performance, it should ensure correctness when dealing with sub-byte transposes.