Skip to content

Commit 8ba018d

Browse files
authored
[mlir][Vector] Add support for sub-byte transpose emulation (llvm#80110)
This PR adds patterns to convert a sub-byte vector transpose into a sequence of instructions that perform the transpose on i8 vector elements. Although this rewrite may not deliver peak performance, it ensures correctness when dealing with sub-byte transposes.
1 parent 0e8eb44 commit 8ba018d

File tree

5 files changed

+80
-2
lines changed

5 files changed

+80
-2
lines changed

mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def ApplyLowerMaskedTransfersPatternsOp : Op<Transform_Dialect,
151151
"apply_patterns.vector.lower_masked_transfers",
152152
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
153153
let description = [{
154-
Apply opt-in patterns that lower vector.mask operations surrounding
154+
Apply opt-in patterns that lower vector.mask operations surrounding
155155
side-effecting ops:
156156
- MaskedTransferReadOpPattern
157157
- MaskedTransferWriteOpPattern
@@ -376,7 +376,7 @@ def ApplyVectorReductionToContractPatternsOp : Op<Transform_Dialect,
376376
- ReorderCastOpsOnBroadcast
377377
- ReorderElementwiseOpsOnTranspose
378378

379-
These patterns have the effect of rewriting a vector.multi_reduce into a
379+
These patterns have the effect of rewriting a vector.multi_reduce into a
380380
vector.contract.
381381
}];
382382

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,10 @@ FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
371371
void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
372372
PatternBenefit benefit = 1);
373373

374+
/// Appends patterns for emulating a sub-byte vector transpose.
375+
void populateVectorTransposeNarrowTypeRewritePatterns(
376+
RewritePatternSet &patterns, PatternBenefit benefit = 1);
377+
374378
} // namespace vector
375379
} // namespace mlir
376380

mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ void transform::ApplyLowerTransposePatternsOp::populatePatterns(
162162
void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
163163
RewritePatternSet &patterns) {
164164
populateVectorNarrowTypeRewritePatterns(patterns);
165+
populateVectorTransposeNarrowTypeRewritePatterns(patterns);
165166
}
166167

167168
void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns(

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1052,6 +1052,53 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
10521052
}
10531053
};
10541054

1055+
/// Rewrite a sub-byte vector transpose into a sequence of instructions that
1056+
/// perform the transpose on wider (byte) element types.
1057+
/// For example:
1058+
/// %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
1059+
///
1060+
/// is rewritten as:
1061+
///
1062+
/// %0 = arith.extsi %a : vector<8x16xi4> to vector<8x16xi8>
1063+
/// %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8>
1064+
/// %2 = arith.trunci %1 : vector<16x8xi8> to vector<16x8xi4>
1065+
///
1066+
struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
1067+
using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
1068+
1069+
RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit)
1070+
: OpRewritePattern<vector::TransposeOp>(context, benefit) {}
1071+
1072+
LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
1073+
PatternRewriter &rewriter) const override {
1074+
// Precondition: sub-byte integer transpose.
1075+
constexpr unsigned minNativeBitwidth = 8;
1076+
VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
1077+
if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
1078+
srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
1079+
return rewriter.notifyMatchFailure(transposeOp,
1080+
"not a sub-byte transpose");
1081+
}
1082+
1083+
// Perform the rewrite.
1084+
Location loc = transposeOp.getLoc();
1085+
// Signed/unsigned interpretation shouldn't matter here as we are just
1086+
// transposing the elements and truncating them back to the original size.
1087+
// TODO: Use unsigned extension (more efficient) when emulation or backend
1088+
// support is available.
1089+
auto srcNativeVecType = srcSubByteVecType.cloneWith(
1090+
std::nullopt, rewriter.getIntegerType(minNativeBitwidth));
1091+
Value extOp = rewriter.create<arith::ExtSIOp>(loc, srcNativeVecType,
1092+
transposeOp.getVector());
1093+
Value newTranspose = rewriter.create<vector::TransposeOp>(
1094+
loc, extOp, transposeOp.getPermutation());
1095+
VectorType dstSubByteVecType = transposeOp.getResultVectorType();
1096+
rewriter.replaceOpWithNewOp<arith::TruncIOp>(transposeOp, dstSubByteVecType,
1097+
newTranspose);
1098+
return success();
1099+
}
1100+
};
1101+
10551102
} // namespace
10561103

10571104
//===----------------------------------------------------------------------===//
@@ -1080,3 +1127,8 @@ void vector::populateVectorNarrowTypeRewritePatterns(
10801127
RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>>(
10811128
patterns.getContext(), benefit.getBenefit() + 1);
10821129
}
1130+
1131+
void vector::populateVectorTransposeNarrowTypeRewritePatterns(
1132+
RewritePatternSet &patterns, PatternBenefit benefit) {
1133+
patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
1134+
}

mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,26 @@ func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> {
226226
return %0 : vector<8xf32>
227227
}
228228

229+
// CHECK-LABEL: func.func @i4_transpose(
230+
// CHECK-SAME: %[[A:[0-9a-z]*]]
231+
func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> {
232+
// CHECK: %[[EXT:.*]] = arith.extsi %[[A]] : vector<8x16xi4> to vector<8x16xi8>
233+
// CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
234+
// CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi4>
235+
%0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
236+
return %0 : vector<16x8xi4>
237+
}
238+
239+
// CHECK-LABEL: func.func @i7_transpose(
240+
// CHECK-SAME: %[[A:[0-9a-z]*]]
241+
func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> {
242+
// CHECK: %[[EXT:.*]] = arith.extsi %[[A]] : vector<8x16xi7> to vector<8x16xi8>
243+
// CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
244+
// CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi7>
245+
%0 = vector.transpose %a, [1, 0] : vector<8x16xi7> to vector<16x8xi7>
246+
return %0 : vector<16x8xi7>
247+
}
248+
229249
module attributes {transform.with_named_sequence} {
230250
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
231251
%f = transform.structured.match ops{["func.func"]} in %module_op
@@ -237,3 +257,4 @@ module attributes {transform.with_named_sequence} {
237257
transform.yield
238258
}
239259
}
260+

0 commit comments

Comments
 (0)