-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][Vector] Replace vector.shuffle
with vector.interleave
in vector narrow type emulation
#82550
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…ector narrow type emulation This PR replaces the generation of `vector.shuffle` with `vector.interleave` in the i4 conversions in vector narrow type emulation. The multi dimensional semantics of `vector.interleave` allow us to enable these conversion emulations also for multi dimensional vectors.
@llvm/pr-subscribers-mlir-vector Author: Diego Caballero (dcaballe) ChangesThis PR replaces the generation of Full diff: https://github.com/llvm/llvm-project/pull/82550.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 36fb66708407b4..9ebe36cd3861e0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -724,9 +724,8 @@ BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
VectorType preconditionType,
Operation *op) {
- if (!preconditionType || preconditionType.getRank() != 1 ||
- preconditionType.isScalable())
- return rewriter.notifyMatchFailure(op, "scalable or >1-D vector");
+ if (!preconditionType || preconditionType.isScalable())
+ return rewriter.notifyMatchFailure(op, "scalable vector");
// TODO: consider relaxing this restriction in the future if we find ways
// to really work with subbyte elements across the MLIR/LLVM boundary.
@@ -743,6 +742,9 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
return rewriter.notifyMatchFailure(op, "types are not vector");
+ if (!preconditionType || preconditionType.getRank() != 1)
+ return rewriter.notifyMatchFailure(op, "unsupported >1-D vector");
+
return commonConversionPrecondition(rewriter, preconditionType, op);
}
@@ -879,8 +881,7 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
interleaveMaskValues.push_back(i + (vecDimSize / 2));
}
- return rewriter.create<vector::ShuffleOp>(
- loc, low, high, rewriter.getI64ArrayAttr(interleaveMaskValues));
+ return rewriter.create<vector::InterleaveOp>(loc, low, high);
}
namespace {
@@ -1008,8 +1009,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
/// %1 = arith.shli %0, 4 : vector<4xi8>
/// %2 = arith.shrsi %1, 4 : vector<4xi8>
/// %3 = arith.shrsi %0, 4 : vector<4xi8>
-/// %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7]
-/// : vector<4xi8>, vector<4xi8>
+/// %4 = vector.interleave %2, %3 : vector<4xi8>
/// %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
///
/// arith.sitofp %in : vector<8xi4> to vector<8xf32>
@@ -1018,8 +1018,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
/// %1 = arith.shli %0, 4 : vector<4xi8>
/// %2 = arith.shrsi %1, 4 : vector<4xi8>
/// %3 = arith.shrsi %0, 4 : vector<4xi8>
-/// %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7]
-/// : vector<4xi8>, vector<4xi8>
+/// %4 = vector.interleave %2, %3 : vector<4xi8>
/// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
///
template <typename ConversionOpType>
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 02063a81664b81..94e78ce40a3c19 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -195,53 +195,89 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> {
// CHECK-LABEL: func.func @aligned_extsi(
func.func @aligned_extsi(%a: vector<8xi4>) -> vector<8xi32> {
- // CHECK: arith.shli
- // CHECK: arith.shrsi
- // CHECK: arith.shrsi
- // CHECK: vector.shuffle
- // CHECK: arith.extsi %{{.*}} : vector<8xi8> to vector<8xi32>
+// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK: %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
%0 = arith.extsi %a : vector<8xi4> to vector<8xi32>
return %0 : vector<8xi32>
}
+// CHECK-LABEL: func.func @aligned_extsi_2d(
+func.func @aligned_extsi_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK: %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+ %0 = arith.extsi %a : vector<8x32xi4> to vector<8x32xi32>
+ return %0 : vector<8x32xi32>
+}
+
// CHECK-LABEL: func.func @aligned_extsi_base_case(
func.func @aligned_extsi_base_case(%a: vector<8xi4>) -> vector<8xi8> {
- // CHECK: arith.shli
- // CHECK: arith.shrsi
- // CHECK: arith.shrsi
- // CHECK: vector.shuffle
- // CHECK-NOT: arith.extsi
+// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
%0 = arith.extsi %a : vector<8xi4> to vector<8xi8>
return %0 : vector<8xi8>
}
// CHECK-LABEL: func.func @aligned_sitofp(
func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> {
- // CHECK: arith.shli
- // CHECK: arith.shrsi
- // CHECK: arith.shrsi
- // CHECK: shuffle
- // CHECK: arith.sitofp %{{.*}} : vector<8xi8> to vector<8xf32>
+// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK: %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32>
%0 = arith.sitofp %a : vector<8xi4> to vector<8xf32>
return %0 : vector<8xf32>
}
+// CHECK-LABEL: func.func @aligned_sitofp_2d(
+func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK: %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32>
+ %0 = arith.sitofp %a : vector<8x32xi4> to vector<8x32xf32>
+ return %0 : vector<8x32xf32>
+}
+
// CHECK-LABEL: func.func @i4_transpose(
-// CHECK-SAME: %[[A:[0-9a-z]*]]
func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> {
- // CHECK: %[[EXT:.*]] = arith.extsi %[[A]] : vector<8x16xi4> to vector<8x16xi8>
- // CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
- // CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi4>
+// CHECK-SAME: %[[IN:.*]]: vector<8x16xi4>) -> vector<16x8xi4> {
+// CHECK: %[[EXT:.*]] = vector.interleave
+// CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
+// CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi4>
%0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
return %0 : vector<16x8xi4>
}
// CHECK-LABEL: func.func @i7_transpose(
-// CHECK-SAME: %[[A:[0-9a-z]*]]
func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> {
- // CHECK: %[[EXT:.*]] = arith.extsi %[[A]] : vector<8x16xi7> to vector<8x16xi8>
- // CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
- // CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi7>
+// CHECK-SAME: %[[IN:.*]]: vector<8x16xi7>) -> vector<16x8xi7> {
+// CHECK: %[[EXT:.*]] = arith.extsi %[[IN]] : vector<8x16xi7> to vector<8x16xi8>
+// CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
+// CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi7>
%0 = vector.transpose %a, [1, 0] : vector<8x16xi7> to vector<16x8xi7>
return %0 : vector<16x8xi7>
}
|
@llvm/pr-subscribers-mlir Author: Diego Caballero (dcaballe) ChangesThis PR replaces the generation of Full diff: https://github.com/llvm/llvm-project/pull/82550.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 36fb66708407b4..9ebe36cd3861e0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -724,9 +724,8 @@ BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
VectorType preconditionType,
Operation *op) {
- if (!preconditionType || preconditionType.getRank() != 1 ||
- preconditionType.isScalable())
- return rewriter.notifyMatchFailure(op, "scalable or >1-D vector");
+ if (!preconditionType || preconditionType.isScalable())
+ return rewriter.notifyMatchFailure(op, "scalable vector");
// TODO: consider relaxing this restriction in the future if we find ways
// to really work with subbyte elements across the MLIR/LLVM boundary.
@@ -743,6 +742,9 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
return rewriter.notifyMatchFailure(op, "types are not vector");
+ if (!preconditionType || preconditionType.getRank() != 1)
+ return rewriter.notifyMatchFailure(op, "unsupported >1-D vector");
+
return commonConversionPrecondition(rewriter, preconditionType, op);
}
@@ -879,8 +881,7 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
interleaveMaskValues.push_back(i + (vecDimSize / 2));
}
- return rewriter.create<vector::ShuffleOp>(
- loc, low, high, rewriter.getI64ArrayAttr(interleaveMaskValues));
+ return rewriter.create<vector::InterleaveOp>(loc, low, high);
}
namespace {
@@ -1008,8 +1009,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
/// %1 = arith.shli %0, 4 : vector<4xi8>
/// %2 = arith.shrsi %1, 4 : vector<4xi8>
/// %3 = arith.shrsi %0, 4 : vector<4xi8>
-/// %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7]
-/// : vector<4xi8>, vector<4xi8>
+/// %4 = vector.interleave %2, %3 : vector<4xi8>
/// %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
///
/// arith.sitofp %in : vector<8xi4> to vector<8xf32>
@@ -1018,8 +1018,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
/// %1 = arith.shli %0, 4 : vector<4xi8>
/// %2 = arith.shrsi %1, 4 : vector<4xi8>
/// %3 = arith.shrsi %0, 4 : vector<4xi8>
-/// %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7]
-/// : vector<4xi8>, vector<4xi8>
+/// %4 = vector.interleave %2, %3 : vector<4xi8>
/// %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
///
template <typename ConversionOpType>
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 02063a81664b81..94e78ce40a3c19 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -195,53 +195,89 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> {
// CHECK-LABEL: func.func @aligned_extsi(
func.func @aligned_extsi(%a: vector<8xi4>) -> vector<8xi32> {
- // CHECK: arith.shli
- // CHECK: arith.shrsi
- // CHECK: arith.shrsi
- // CHECK: vector.shuffle
- // CHECK: arith.extsi %{{.*}} : vector<8xi8> to vector<8xi32>
+// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK: %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
%0 = arith.extsi %a : vector<8xi4> to vector<8xi32>
return %0 : vector<8xi32>
}
+// CHECK-LABEL: func.func @aligned_extsi_2d(
+func.func @aligned_extsi_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK: %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+ %0 = arith.extsi %a : vector<8x32xi4> to vector<8x32xi32>
+ return %0 : vector<8x32xi32>
+}
+
// CHECK-LABEL: func.func @aligned_extsi_base_case(
func.func @aligned_extsi_base_case(%a: vector<8xi4>) -> vector<8xi8> {
- // CHECK: arith.shli
- // CHECK: arith.shrsi
- // CHECK: arith.shrsi
- // CHECK: vector.shuffle
- // CHECK-NOT: arith.extsi
+// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
%0 = arith.extsi %a : vector<8xi4> to vector<8xi8>
return %0 : vector<8xi8>
}
// CHECK-LABEL: func.func @aligned_sitofp(
func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> {
- // CHECK: arith.shli
- // CHECK: arith.shrsi
- // CHECK: arith.shrsi
- // CHECK: shuffle
- // CHECK: arith.sitofp %{{.*}} : vector<8xi8> to vector<8xf32>
+// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK: %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32>
%0 = arith.sitofp %a : vector<8xi4> to vector<8xf32>
return %0 : vector<8xf32>
}
+// CHECK-LABEL: func.func @aligned_sitofp_2d(
+func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> {
+// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK: %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32>
+ %0 = arith.sitofp %a : vector<8x32xi4> to vector<8x32xf32>
+ return %0 : vector<8x32xf32>
+}
+
// CHECK-LABEL: func.func @i4_transpose(
-// CHECK-SAME: %[[A:[0-9a-z]*]]
func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> {
- // CHECK: %[[EXT:.*]] = arith.extsi %[[A]] : vector<8x16xi4> to vector<8x16xi8>
- // CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
- // CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi4>
+// CHECK-SAME: %[[IN:.*]]: vector<8x16xi4>) -> vector<16x8xi4> {
+// CHECK: %[[EXT:.*]] = vector.interleave
+// CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
+// CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi4>
%0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
return %0 : vector<16x8xi4>
}
// CHECK-LABEL: func.func @i7_transpose(
-// CHECK-SAME: %[[A:[0-9a-z]*]]
func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> {
- // CHECK: %[[EXT:.*]] = arith.extsi %[[A]] : vector<8x16xi7> to vector<8x16xi8>
- // CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
- // CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi7>
+// CHECK-SAME: %[[IN:.*]]: vector<8x16xi7>) -> vector<16x8xi7> {
+// CHECK: %[[EXT:.*]] = arith.extsi %[[IN]] : vector<8x16xi7> to vector<8x16xi8>
+// CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
+// CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi7>
%0 = vector.transpose %a, [1, 0] : vector<8x16xi7> to vector<16x8xi7>
return %0 : vector<16x8xi7>
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks better, thanks!
@@ -879,8 +881,7 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc, | |||
interleaveMaskValues.push_back(i + (vecDimSize / 2)); | |||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can the interleaveMaskValues
(and the loop generate it) be removed now? It looks unused.
Also // 3. Interleave low and high i8 elements using a shuffle.
is now outdated :)
preconditionType.isScalable()) | ||
return rewriter.notifyMatchFailure(op, "scalable or >1-D vector"); | ||
if (!preconditionType || preconditionType.isScalable()) | ||
return rewriter.notifyMatchFailure(op, "scalable vector"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there anything stopping allowing this for scalable vectors too now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would say broader validation to make sure the backend is ok with the i4
pieces but I currently don't have a large workload that compiles using scalable vectors.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
…ector narrow type emulation (llvm#82550) This PR replaces the generation of `vector.shuffle` with `vector.interleave` in the i4 conversions in vector narrow type emulation. The multi dimensional semantics of `vector.interleave` allow us to enable these conversion emulations also for multi dimensional vectors.
This PR replaces the generation of
vector.shuffle
withvector.interleave
in the i4 conversions in vector narrow type emulation. The multi dimensional semantics ofvector.interleave
allow us to enable these conversion emulations also for multi dimensional vectors.