Commit b8fb65d
[mlir][Vector] Add patterns for efficient i4 -> i8 conversion emulation
This PR adds new patterns to improve the generated vector code for the emulation of any conversion that has to go through an i4 -> i8 type extension (only signed extensions are supported for now). This will impact any i4 -> i8/i16/i32/i64 signed extension as well as sitofp i4 -> f8/f16/f32/f64. The asm code generated for the supported cases is significantly better after this PR on both x86 and aarch64.
1 parent 19b65a9 commit b8fb65d
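
As an illustration (taken from the pattern documentation added in this commit), the signed extension

    arith.extsi %in : vector<8xi4> to vector<8xi32>

is rewritten as

    %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
    %1 = arith.shli %0, 4 : vector<4xi8>
    %2 = arith.shrsi %1, 4 : vector<4xi8>
    %3 = arith.shrsi %0, 4 : vector<4xi8>
    %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7] : vector<4xi8>, vector<4xi8>
    %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>

so the sub-byte part of the conversion is expressed directly on whole bytes and LLVM no longer has to recover the nibble structure with peephole optimizations.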

2 files changed (+189 −20 lines)
mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 156 additions & 20 deletions
@@ -642,19 +642,19 @@ struct BitCastRewriter {
 
   BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);
 
-  /// Verify that the preconditions for the rewrite are met.
-  LogicalResult precondition(PatternRewriter &rewriter,
-                             VectorType preconditionVectorType, Operation *op);
+  /// Verify that general preconditions for the rewrite are met.
+  LogicalResult commonPrecondition(PatternRewriter &rewriter,
+                                   VectorType preconditionType, Operation *op);
 
   /// Precompute the metadata for the rewrite.
   SmallVector<BitCastRewriter::Metadata>
   precomputeMetadata(IntegerType shuffledElementType);
 
   /// Rewrite one step of the sequence:
   /// `(shuffle -> and -> shiftright -> shiftleft -> or)`.
-  Value rewriteStep(PatternRewriter &rewriter, Location loc, Value initialValue,
-                    Value runningResult,
-                    const BitCastRewriter::Metadata &metadata);
+  Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
+                           Value initialValue, Value runningResult,
+                           const BitCastRewriter::Metadata &metadata);
 
 private:
   /// Underlying enumerator that encodes the provenance of the bits in the each
@@ -719,21 +719,57 @@ BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
   LDBG("\n" << enumerator.sourceElementRanges);
 }
 
-LogicalResult BitCastRewriter::precondition(PatternRewriter &rewriter,
-                                            VectorType precondition,
-                                            Operation *op) {
-  if (precondition.getRank() != 1 || precondition.isScalable())
+/// Verify that the precondition type meets the common preconditions for any
+/// conversion.
+static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
+                                                  VectorType preconditionType,
+                                                  Operation *op) {
+  if (!preconditionType || preconditionType.getRank() != 1 ||
+      preconditionType.isScalable())
     return rewriter.notifyMatchFailure(op, "scalable or >1-D vector");
 
   // TODO: consider relaxing this restriction in the future if we find ways
   // to really work with subbyte elements across the MLIR/LLVM boundary.
-  int64_t resultBitwidth = precondition.getElementTypeBitWidth();
+  unsigned resultBitwidth = preconditionType.getElementTypeBitWidth();
   if (resultBitwidth % 8 != 0)
     return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
 
   return success();
 }
 
+LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
+                                                  VectorType preconditionType,
+                                                  Operation *op) {
+  if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
+    return rewriter.notifyMatchFailure(op, "types are not vector");
+
+  return commonConversionPrecondition(rewriter, preconditionType, op);
+}
+
+/// Verify that source and destination element types meet the precondition for
+/// the supported aligned conversion cases. Alignment means that either the
+/// source element type is a multiple of the destination element type or the
+/// other way around.
+///
+/// NOTE: This method assumes that common conversion preconditions are met.
+static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
+                                                   VectorType srcType,
+                                                   VectorType dstType,
+                                                   Operation *op) {
+  if (!srcType || !dstType)
+    return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
+  unsigned srcElemBitwidth = srcType.getElementTypeBitWidth();
+  unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
+  unsigned byteBitwidth = 8;
+
+  // Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now.
+  if (srcElemBitwidth != 4 || dstElemBitwidth < byteBitwidth ||
+      (dstElemBitwidth % srcElemBitwidth) != 0)
+    return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
+
+  return success();
+}
+
 SmallVector<BitCastRewriter::Metadata>
 BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
   SmallVector<BitCastRewriter::Metadata> result;
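
Taken together, the two preconditions mean the aligned path currently fires only when the source element type is i4 and the destination element type is a byte or wider. A few illustrative cases (a sketch; the op shapes here are examples, not an exhaustive list):

    arith.extsi  %a : vector<8xi4> to vector<8xi8>    // accepted: i4 into a full byte
    arith.sitofp %a : vector<8xi4> to vector<8xf32>   // accepted: f32 is 32 bits wide
    arith.extsi  %a : vector<8xi2> to vector<8xi8>    // rejected: source is not i4
    arith.extsi  %a : vector<8xi8> to vector<8xi32>   // rejected: source is not i4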
@@ -775,9 +811,9 @@ BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
   return result;
 }
 
-Value BitCastRewriter::rewriteStep(PatternRewriter &rewriter, Location loc,
-                                   Value initialValue, Value runningResult,
-                                   const BitCastRewriter::Metadata &metadata) {
+Value BitCastRewriter::genericRewriteStep(
+    PatternRewriter &rewriter, Location loc, Value initialValue,
+    Value runningResult, const BitCastRewriter::Metadata &metadata) {
   // Create vector.shuffle from the metadata.
   auto shuffleOp = rewriter.create<vector::ShuffleOp>(
       loc, initialValue, initialValue, metadata.shuffles);
@@ -810,6 +846,44 @@ Value BitCastRewriter::rewriteStep(PatternRewriter &rewriter, Location loc,
   return runningResult;
 }
 
+/// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
+                                    Value srcValue) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  assert(srcVecType.getElementType().isSignlessInteger(4) &&
+         "Expected i4 type");
+
+  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
+  int64_t vecDimSize = srcVecType.getShape().back();
+  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+  constexpr int64_t i4Toi8BitwidthFactor = 2;
+  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
+  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+  // 2. Extend i4 elements to i8 elements using shifts. Low i4 elements of each
+  // byte are placed in one vector and the high i4 elements in another vector.
+  constexpr int8_t bitsToShift = 4;
+  auto shiftValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(i8VecType, bitsToShift));
+  Value shl = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues);
+  Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues);
+  Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues);
+
+  // 3. Interleave low and high i8 elements using a shuffle.
+  SmallVector<int64_t> interleaveMaskValues;
+  interleaveMaskValues.reserve(vecDimSize);
+  for (int i = 0, end = vecDimSize / 2; i < end; ++i) {
+    interleaveMaskValues.push_back(i);
+    interleaveMaskValues.push_back(i + (vecDimSize / 2));
+  }
+
+  return rewriter.create<vector::ShuffleOp>(
+      loc, low, high, rewriter.getI64ArrayAttr(interleaveMaskValues));
+}
+
 namespace {
 /// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
 /// advantage of high-level information to avoid leaving LLVM to scramble with
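
Editorial note on the shift trick above: the shli/shrsi pair is what makes the signed semantics come out right. Assuming the usual little-endian nibble packing after the bitcast, consider a byte holding 0xA7, i.e. low nibble 0x7 (+7) and high nibble 0xA (-6 as a signed i4):

    shli  0xA7, 4  ->  0x70   // low nibble moved into the sign position
    shrsi 0x70, 4  ->  0x07   // = +7, the low nibble sign-extended to i8
    shrsi 0xA7, 4  ->  0xFA   // = -6, the high nibble sign-extended to i8

The final shuffle then interleaves the two half-length vectors so the elements come back in their original order.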
@@ -829,7 +903,7 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
     VectorType sourceVectorType = bitCastOp.getSourceVectorType();
     VectorType targetVectorType = bitCastOp.getResultVectorType();
     BitCastRewriter bcr(sourceVectorType, targetVectorType);
-    if (failed(bcr.precondition(rewriter, targetVectorType, bitCastOp)))
+    if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
       return failure();
 
     // Perform the rewrite.
@@ -839,8 +913,8 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
     Value runningResult;
     for (const BitCastRewriter ::Metadata &metadata :
          bcr.precomputeMetadata(shuffledElementType)) {
-      runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(), truncValue,
-                                      runningResult, metadata);
+      runningResult = bcr.genericRewriteStep(
+          rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
     }
 
     // Finalize the rewrite.
@@ -893,7 +967,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
     VectorType sourceVectorType = bitCastOp.getSourceVectorType();
     VectorType targetVectorType = bitCastOp.getResultVectorType();
     BitCastRewriter bcr(sourceVectorType, targetVectorType);
-    if (failed(bcr.precondition(
+    if (failed(bcr.commonPrecondition(
             rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
       return failure();
 
@@ -904,8 +978,8 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
         cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
     for (const BitCastRewriter::Metadata &metadata :
          bcr.precomputeMetadata(shuffledElementType)) {
-      runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(),
-                                      sourceValue, runningResult, metadata);
+      runningResult = bcr.genericRewriteStep(
+          rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
     }
 
     // Finalize the rewrite.
@@ -923,6 +997,62 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
     return success();
   }
 };
+
+/// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+///
+/// For example:
+///   arith.extsi %in : vector<8xi4> to vector<8xi32>
+/// is rewritten as:
+///   %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///   %1 = arith.shli %0, 4 : vector<4xi8>
+///   %2 = arith.shrsi %1, 4 : vector<4xi8>
+///   %3 = arith.shrsi %0, 4 : vector<4xi8>
+///   %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7]
+///          : vector<4xi8>, vector<4xi8>
+///   %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
+///
+///   arith.sitofp %in : vector<8xi4> to vector<8xf32>
+/// is rewritten as:
+///   %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///   %1 = arith.shli %0, 4 : vector<4xi8>
+///   %2 = arith.shrsi %1, 4 : vector<4xi8>
+///   %3 = arith.shrsi %0, 4 : vector<4xi8>
+///   %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7]
+///          : vector<4xi8>, vector<4xi8>
+///   %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
+///
+template <typename ConversionOpType>
+struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
+  using OpRewritePattern<ConversionOpType>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
+                                PatternRewriter &rewriter) const override {
+    // Verify the common conversion preconditions.
+    Value srcValue = conversionOp.getIn();
+    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
+    auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
+    if (failed(
+            commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
+      return failure();
+
+    // Check general alignment preconditions.
+    if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
+                                             conversionOp)))
+      return failure();
+
+    // Perform the rewrite.
+    Value subByteExt =
+        rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+
+    // Finalize the rewrite.
+    rewriter.replaceOpWithNewOp<ConversionOpType>(
+        conversionOp, conversionOp.getType(), subByteExt);
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
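
Note that only the signed conversions are handled at this point: the pattern is instantiated below for arith::ExtSIOp and arith::SIToFPOp only, so an unsigned extension such as

    arith.extui %a : vector<8xi4> to vector<8xi32>

is left untouched by the new aligned path and is handled, if at all, by the pre-existing patterns.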
@@ -944,4 +1074,10 @@ void vector::populateVectorNarrowTypeRewritePatterns(
   patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
               RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
                                                    benefit);
+
+  // Patterns for aligned cases. We set a higher benefit as they are expected
+  // to generate better code for the cases they cover.
+  patterns.add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
+               RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>>(
+      patterns.getContext(), benefit.getBenefit() + 1);
 }
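
For context, a minimal sketch of how a client pass might apply these patterns with the greedy driver (the wrapper function and its name are assumptions for illustration, not part of the commit):

    #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    // Populate the narrow-type rewrites and apply them greedily to `op`.
    // The aligned i4 -> i8 patterns registered above win over the generic
    // ones on overlapping matches because they carry `benefit + 1`.
    static mlir::LogicalResult runNarrowTypeRewrites(mlir::Operation *op) {
      mlir::RewritePatternSet patterns(op->getContext());
      mlir::vector::populateVectorNarrowTypeRewritePatterns(patterns);
      return mlir::applyPatternsAndFoldGreedily(op, std::move(patterns));
    }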

mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir

Lines changed: 33 additions & 0 deletions
@@ -193,6 +193,39 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> {
   return %1 : vector<8xi17>
 }
 
+// CHECK-LABEL: func.func @aligned_extsi(
+func.func @aligned_extsi(%a: vector<8xi4>) -> vector<8xi32> {
+  // CHECK: arith.shli
+  // CHECK: arith.shrsi
+  // CHECK: arith.shrsi
+  // CHECK: vector.shuffle
+  // CHECK: arith.extsi %{{.*}} : vector<8xi8> to vector<8xi32>
+  %0 = arith.extsi %a : vector<8xi4> to vector<8xi32>
+  return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extsi_base_case(
+func.func @aligned_extsi_base_case(%a: vector<8xi4>) -> vector<8xi8> {
+  // CHECK: arith.shli
+  // CHECK: arith.shrsi
+  // CHECK: arith.shrsi
+  // CHECK: vector.shuffle
+  // CHECK-NOT: arith.extsi
+  %0 = arith.extsi %a : vector<8xi4> to vector<8xi8>
+  return %0 : vector<8xi8>
+}
+
+// CHECK-LABEL: func.func @aligned_sitofp(
+func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> {
+  // CHECK: arith.shli
+  // CHECK: arith.shrsi
+  // CHECK: arith.shrsi
+  // CHECK: shuffle
+  // CHECK: arith.sitofp %{{.*}} : vector<8xi8> to vector<8xf32>
+  %0 = arith.sitofp %a : vector<8xi4> to vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op
