Skip to content

Commit 386aa7b

Browse files
authored
[mlir][Vector] Replace vector.shuffle with vector.interleave in vector narrow type emulation (#82550)
This PR replaces the generation of `vector.shuffle` with `vector.interleave` in the i4 conversions in vector narrow type emulation. The multi dimensional semantics of `vector.interleave` allow us to enable these conversion emulations also for multi dimensional vectors.
1 parent 0e8d187 commit 386aa7b

File tree

2 files changed

+68
-41
lines changed

2 files changed

+68
-41
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -724,9 +724,8 @@ BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
724724
static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
725725
VectorType preconditionType,
726726
Operation *op) {
727-
if (!preconditionType || preconditionType.getRank() != 1 ||
728-
preconditionType.isScalable())
729-
return rewriter.notifyMatchFailure(op, "scalable or >1-D vector");
727+
if (!preconditionType || preconditionType.isScalable())
728+
return rewriter.notifyMatchFailure(op, "scalable vector");
730729

731730
// TODO: consider relaxing this restriction in the future if we find ways
732731
// to really work with subbyte elements across the MLIR/LLVM boundary.
@@ -743,6 +742,9 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
743742
if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
744743
return rewriter.notifyMatchFailure(op, "types are not vector");
745744

745+
if (!preconditionType || preconditionType.getRank() != 1)
746+
return rewriter.notifyMatchFailure(op, "unsupported >1-D vector");
747+
746748
return commonConversionPrecondition(rewriter, preconditionType, op);
747749
}
748750

@@ -855,7 +857,6 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
855857
"Expected i4 type");
856858

857859
// 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
858-
int64_t vecDimSize = srcVecType.getShape().back();
859860
SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
860861
constexpr int64_t i4Toi8BitwidthFactor = 2;
861862
i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
@@ -871,16 +872,8 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
871872
Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues);
872873
Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues);
873874

874-
// 3. Interleave low and high i8 elements using a shuffle.
875-
SmallVector<int64_t> interleaveMaskValues;
876-
interleaveMaskValues.reserve(vecDimSize);
877-
for (int i = 0, end = vecDimSize / 2; i < end; ++i) {
878-
interleaveMaskValues.push_back(i);
879-
interleaveMaskValues.push_back(i + (vecDimSize / 2));
880-
}
881-
882-
return rewriter.create<vector::ShuffleOp>(
883-
loc, low, high, rewriter.getI64ArrayAttr(interleaveMaskValues));
875+
// 3. Interleave low and high i8 elements.
876+
return rewriter.create<vector::InterleaveOp>(loc, low, high);
884877
}
885878

886879
namespace {
@@ -1008,8 +1001,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
10081001
/// %1 = arith.shli %0, 4 : vector<4xi8>
10091002
/// %2 = arith.shrsi %1, 4 : vector<4xi8>
10101003
/// %3 = arith.shrsi %0, 4 : vector<4xi8>
1011-
/// %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7]
1012-
/// : vector<4xi8>, vector<4xi8>
1004+
/// %4 = vector.interleave %2, %3 : vector<4xi8>
10131005
/// %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
10141006
///
10151007
/// arith.sitofp %in : vector<8xi4> to vector<8xf32>
@@ -1018,8 +1010,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
10181010
/// %1 = arith.shli %0, 4 : vector<4xi8>
10191011
/// %2 = arith.shrsi %1, 4 : vector<4xi8>
10201012
/// %3 = arith.shrsi %0, 4 : vector<4xi8>
1021-
/// %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7]
1022-
/// : vector<4xi8>, vector<4xi8>
1013+
/// %4 = vector.interleave %2, %3 : vector<4xi8>
10231014
/// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
10241015
///
10251016
template <typename ConversionOpType>

mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir

Lines changed: 59 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -195,53 +195,89 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> {
195195

196196
// CHECK-LABEL: func.func @aligned_extsi(
197197
func.func @aligned_extsi(%a: vector<8xi4>) -> vector<8xi32> {
198-
// CHECK: arith.shli
199-
// CHECK: arith.shrsi
200-
// CHECK: arith.shrsi
201-
// CHECK: vector.shuffle
202-
// CHECK: arith.extsi %{{.*}} : vector<8xi8> to vector<8xi32>
198+
// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
199+
// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
200+
// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
201+
// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
202+
// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
203+
// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
204+
// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
205+
// CHECK: %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
203206
%0 = arith.extsi %a : vector<8xi4> to vector<8xi32>
204207
return %0 : vector<8xi32>
205208
}
206209

210+
// CHECK-LABEL: func.func @aligned_extsi_2d(
211+
func.func @aligned_extsi_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
212+
// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
213+
// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
214+
// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
215+
// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
216+
// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
217+
// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
218+
// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
219+
// CHECK: %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
220+
%0 = arith.extsi %a : vector<8x32xi4> to vector<8x32xi32>
221+
return %0 : vector<8x32xi32>
222+
}
223+
207224
// CHECK-LABEL: func.func @aligned_extsi_base_case(
208225
func.func @aligned_extsi_base_case(%a: vector<8xi4>) -> vector<8xi8> {
209-
// CHECK: arith.shli
210-
// CHECK: arith.shrsi
211-
// CHECK: arith.shrsi
212-
// CHECK: vector.shuffle
213-
// CHECK-NOT: arith.extsi
226+
// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
227+
// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
228+
// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
229+
// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
230+
// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
231+
// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
232+
// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
214233
%0 = arith.extsi %a : vector<8xi4> to vector<8xi8>
215234
return %0 : vector<8xi8>
216235
}
217236

218237
// CHECK-LABEL: func.func @aligned_sitofp(
219238
func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> {
220-
// CHECK: arith.shli
221-
// CHECK: arith.shrsi
222-
// CHECK: arith.shrsi
223-
// CHECK: shuffle
224-
// CHECK: arith.sitofp %{{.*}} : vector<8xi8> to vector<8xf32>
239+
// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> {
240+
// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
241+
// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
242+
// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
243+
// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
244+
// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
245+
// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
246+
// CHECK: %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32>
225247
%0 = arith.sitofp %a : vector<8xi4> to vector<8xf32>
226248
return %0 : vector<8xf32>
227249
}
228250

251+
// CHECK-LABEL: func.func @aligned_sitofp_2d(
252+
func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
253+
// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> {
254+
// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
255+
// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
256+
// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
257+
// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
258+
// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
259+
// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
260+
// CHECK: %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32>
261+
%0 = arith.sitofp %a : vector<8x32xi4> to vector<8x32xf32>
262+
return %0 : vector<8x32xf32>
263+
}
264+
229265
// CHECK-LABEL: func.func @i4_transpose(
230-
// CHECK-SAME: %[[A:[0-9a-z]*]]
231266
func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> {
232-
// CHECK: %[[EXT:.*]] = arith.extsi %[[A]] : vector<8x16xi4> to vector<8x16xi8>
233-
// CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
234-
// CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi4>
267+
// CHECK-SAME: %[[IN:.*]]: vector<8x16xi4>) -> vector<16x8xi4> {
268+
// CHECK: %[[EXT:.*]] = vector.interleave
269+
// CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
270+
// CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi4>
235271
%0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
236272
return %0 : vector<16x8xi4>
237273
}
238274

239275
// CHECK-LABEL: func.func @i7_transpose(
240-
// CHECK-SAME: %[[A:[0-9a-z]*]]
241276
func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> {
242-
// CHECK: %[[EXT:.*]] = arith.extsi %[[A]] : vector<8x16xi7> to vector<8x16xi8>
243-
// CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
244-
// CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi7>
277+
// CHECK-SAME: %[[IN:.*]]: vector<8x16xi7>) -> vector<16x8xi7> {
278+
// CHECK: %[[EXT:.*]] = arith.extsi %[[IN]] : vector<8x16xi7> to vector<8x16xi8>
279+
// CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
280+
// CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi7>
245281
%0 = vector.transpose %a, [1, 0] : vector<8x16xi7> to vector<16x8xi7>
246282
return %0 : vector<16x8xi7>
247283
}

0 commit comments

Comments
 (0)