Skip to content

[mlir][Vector] Replace vector.shuffle with vector.interleave in vector narrow type emulation #82550

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 22, 2024

Conversation

dcaballe
Copy link
Contributor

This PR replaces the generation of vector.shuffle with vector.interleave in the i4 conversions in vector narrow type emulation. The multi dimensional semantics of vector.interleave allow us to enable these conversion emulations also for multi dimensional vectors.

…ector narrow type emulation

This PR replaces the generation of `vector.shuffle` with
`vector.interleave` in the i4 conversions in vector narrow type
emulation. The multi dimensional semantics of `vector.interleave` allow
us to enable these conversion emulations also for multi dimensional
vectors.
@llvmbot
Copy link
Member

llvmbot commented Feb 21, 2024

@llvm/pr-subscribers-mlir-vector

Author: Diego Caballero (dcaballe)

Changes

This PR replaces the generation of vector.shuffle with vector.interleave in the i4 conversions in vector narrow type emulation. The multi dimensional semantics of vector.interleave allow us to enable these conversion emulations also for multi dimensional vectors.


Full diff: https://github.com/llvm/llvm-project/pull/82550.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+8-9)
  • (modified) mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir (+59-23)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 36fb66708407b4..9ebe36cd3861e0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -724,9 +724,8 @@ BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
 static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
                                                   VectorType preconditionType,
                                                   Operation *op) {
-  if (!preconditionType || preconditionType.getRank() != 1 ||
-      preconditionType.isScalable())
-    return rewriter.notifyMatchFailure(op, "scalable or >1-D vector");
+  if (!preconditionType || preconditionType.isScalable())
+    return rewriter.notifyMatchFailure(op, "scalable vector");
 
   // TODO: consider relaxing this restriction in the future if we find ways
   // to really work with subbyte elements across the MLIR/LLVM boundary.
@@ -743,6 +742,9 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
   if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
     return rewriter.notifyMatchFailure(op, "types are not vector");
 
+  if (!preconditionType || preconditionType.getRank() != 1)
+    return rewriter.notifyMatchFailure(op, "unsupported >1-D vector");
+
   return commonConversionPrecondition(rewriter, preconditionType, op);
 }
 
@@ -879,8 +881,7 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
     interleaveMaskValues.push_back(i + (vecDimSize / 2));
   }
 
-  return rewriter.create<vector::ShuffleOp>(
-      loc, low, high, rewriter.getI64ArrayAttr(interleaveMaskValues));
+  return rewriter.create<vector::InterleaveOp>(loc, low, high);
 }
 
 namespace {
@@ -1008,8 +1009,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 ///        %1 = arith.shli %0, 4 : vector<4xi8>
 ///        %2 = arith.shrsi %1, 4 : vector<4xi8>
 ///        %3 = arith.shrsi %0, 4 : vector<4xi8>
-///        %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7]
-///             : vector<4xi8>, vector<4xi8>
+///        %4 = vector.interleave %2, %3 : vector<4xi8>
 ///        %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
 ///
 ///    arith.sitofp %in : vector<8xi4> to vector<8xf32>
@@ -1018,8 +1018,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 ///        %1 = arith.shli %0, 4 : vector<4xi8>
 ///        %2 = arith.shrsi %1, 4 : vector<4xi8>
 ///        %3 = arith.shrsi %0, 4 : vector<4xi8>
-///        %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7]
-///             : vector<4xi8>, vector<4xi8>
+///        %4 = vector.interleave %2, %3 : vector<4xi8>
 ///        %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
 ///
 template <typename ConversionOpType>
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 02063a81664b81..94e78ce40a3c19 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -195,53 +195,89 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> {
 
 // CHECK-LABEL: func.func @aligned_extsi(
 func.func @aligned_extsi(%a: vector<8xi4>) -> vector<8xi32> {
-  // CHECK: arith.shli
-  // CHECK: arith.shrsi
-  // CHECK: arith.shrsi
-  // CHECK: vector.shuffle
-  // CHECK: arith.extsi %{{.*}} : vector<8xi8> to vector<8xi32>
+// CHECK-SAME:    %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK:           %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
   %0 = arith.extsi %a : vector<8xi4> to vector<8xi32>
   return %0 : vector<8xi32>
 }
 
+// CHECK-LABEL: func.func @aligned_extsi_2d(
+func.func @aligned_extsi_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK:           %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+  %0 = arith.extsi %a : vector<8x32xi4> to vector<8x32xi32>
+  return %0 : vector<8x32xi32>
+}
+
 // CHECK-LABEL: func.func @aligned_extsi_base_case(
 func.func @aligned_extsi_base_case(%a: vector<8xi4>) -> vector<8xi8> {
-  // CHECK: arith.shli
-  // CHECK: arith.shrsi
-  // CHECK: arith.shrsi
-  // CHECK: vector.shuffle
-  // CHECK-NOT: arith.extsi
+// CHECK-SAME:    %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
   %0 = arith.extsi %a : vector<8xi4> to vector<8xi8>
   return %0 : vector<8xi8>
 }
 
 // CHECK-LABEL: func.func @aligned_sitofp(
 func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> {
-  // CHECK: arith.shli
-  // CHECK: arith.shrsi
-  // CHECK: arith.shrsi
-  // CHECK: shuffle
-  // CHECK: arith.sitofp %{{.*}} : vector<8xi8> to vector<8xf32>
+// CHECK-SAME:    %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK:           %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32>
   %0 = arith.sitofp %a : vector<8xi4> to vector<8xf32>
   return %0 : vector<8xf32>
 }
 
+// CHECK-LABEL: func.func @aligned_sitofp_2d(
+func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK:           %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32>
+  %0 = arith.sitofp %a : vector<8x32xi4> to vector<8x32xf32>
+  return %0 : vector<8x32xf32>
+}
+
 // CHECK-LABEL: func.func @i4_transpose(
-//  CHECK-SAME: %[[A:[0-9a-z]*]]
 func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> {
-  // CHECK: %[[EXT:.*]] = arith.extsi %[[A]] : vector<8x16xi4> to vector<8x16xi8>
-  // CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
-  // CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi4>
+// CHECK-SAME:    %[[IN:.*]]: vector<8x16xi4>) -> vector<16x8xi4> {
+// CHECK:           %[[EXT:.*]] = vector.interleave
+// CHECK:           %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
+// CHECK:           %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi4>
   %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
   return %0 : vector<16x8xi4>
 }
 
 // CHECK-LABEL: func.func @i7_transpose(
-//  CHECK-SAME: %[[A:[0-9a-z]*]]
 func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> {
-  // CHECK: %[[EXT:.*]] = arith.extsi %[[A]] : vector<8x16xi7> to vector<8x16xi8>
-  // CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
-  // CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi7>
+// CHECK-SAME:    %[[IN:.*]]: vector<8x16xi7>) -> vector<16x8xi7> {
+// CHECK:           %[[EXT:.*]] = arith.extsi %[[IN]] : vector<8x16xi7> to vector<8x16xi8>
+// CHECK:           %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
+// CHECK:           %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi7>
   %0 = vector.transpose %a, [1, 0] : vector<8x16xi7> to vector<16x8xi7>
   return %0 : vector<16x8xi7>
 }

@llvmbot
Copy link
Member

llvmbot commented Feb 21, 2024

@llvm/pr-subscribers-mlir

Author: Diego Caballero (dcaballe)

Changes

This PR replaces the generation of vector.shuffle with vector.interleave in the i4 conversions in vector narrow type emulation. The multi dimensional semantics of vector.interleave allow us to enable these conversion emulations also for multi dimensional vectors.


Full diff: https://github.com/llvm/llvm-project/pull/82550.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+8-9)
  • (modified) mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir (+59-23)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 36fb66708407b4..9ebe36cd3861e0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -724,9 +724,8 @@ BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
 static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
                                                   VectorType preconditionType,
                                                   Operation *op) {
-  if (!preconditionType || preconditionType.getRank() != 1 ||
-      preconditionType.isScalable())
-    return rewriter.notifyMatchFailure(op, "scalable or >1-D vector");
+  if (!preconditionType || preconditionType.isScalable())
+    return rewriter.notifyMatchFailure(op, "scalable vector");
 
   // TODO: consider relaxing this restriction in the future if we find ways
   // to really work with subbyte elements across the MLIR/LLVM boundary.
@@ -743,6 +742,9 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
   if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
     return rewriter.notifyMatchFailure(op, "types are not vector");
 
+  if (!preconditionType || preconditionType.getRank() != 1)
+    return rewriter.notifyMatchFailure(op, "unsupported >1-D vector");
+
   return commonConversionPrecondition(rewriter, preconditionType, op);
 }
 
@@ -879,8 +881,7 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
     interleaveMaskValues.push_back(i + (vecDimSize / 2));
   }
 
-  return rewriter.create<vector::ShuffleOp>(
-      loc, low, high, rewriter.getI64ArrayAttr(interleaveMaskValues));
+  return rewriter.create<vector::InterleaveOp>(loc, low, high);
 }
 
 namespace {
@@ -1008,8 +1009,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 ///        %1 = arith.shli %0, 4 : vector<4xi8>
 ///        %2 = arith.shrsi %1, 4 : vector<4xi8>
 ///        %3 = arith.shrsi %0, 4 : vector<4xi8>
-///        %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7]
-///             : vector<4xi8>, vector<4xi8>
+///        %4 = vector.interleave %2, %3 : vector<4xi8>
 ///        %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
 ///
 ///    arith.sitofp %in : vector<8xi4> to vector<8xf32>
@@ -1018,8 +1018,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 ///        %1 = arith.shli %0, 4 : vector<4xi8>
 ///        %2 = arith.shrsi %1, 4 : vector<4xi8>
 ///        %3 = arith.shrsi %0, 4 : vector<4xi8>
-///        %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7]
-///             : vector<4xi8>, vector<4xi8>
+///        %4 = vector.interleave %2, %3 : vector<4xi8>
 ///        %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
 ///
 template <typename ConversionOpType>
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 02063a81664b81..94e78ce40a3c19 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -195,53 +195,89 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> {
 
 // CHECK-LABEL: func.func @aligned_extsi(
 func.func @aligned_extsi(%a: vector<8xi4>) -> vector<8xi32> {
-  // CHECK: arith.shli
-  // CHECK: arith.shrsi
-  // CHECK: arith.shrsi
-  // CHECK: vector.shuffle
-  // CHECK: arith.extsi %{{.*}} : vector<8xi8> to vector<8xi32>
+// CHECK-SAME:    %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK:           %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
   %0 = arith.extsi %a : vector<8xi4> to vector<8xi32>
   return %0 : vector<8xi32>
 }
 
+// CHECK-LABEL: func.func @aligned_extsi_2d(
+func.func @aligned_extsi_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK:           %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+  %0 = arith.extsi %a : vector<8x32xi4> to vector<8x32xi32>
+  return %0 : vector<8x32xi32>
+}
+
 // CHECK-LABEL: func.func @aligned_extsi_base_case(
 func.func @aligned_extsi_base_case(%a: vector<8xi4>) -> vector<8xi8> {
-  // CHECK: arith.shli
-  // CHECK: arith.shrsi
-  // CHECK: arith.shrsi
-  // CHECK: vector.shuffle
-  // CHECK-NOT: arith.extsi
+// CHECK-SAME:    %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
   %0 = arith.extsi %a : vector<8xi4> to vector<8xi8>
   return %0 : vector<8xi8>
 }
 
 // CHECK-LABEL: func.func @aligned_sitofp(
 func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> {
-  // CHECK: arith.shli
-  // CHECK: arith.shrsi
-  // CHECK: arith.shrsi
-  // CHECK: shuffle
-  // CHECK: arith.sitofp %{{.*}} : vector<8xi8> to vector<8xf32>
+// CHECK-SAME:    %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
+// CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
+// CHECK:           %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32>
   %0 = arith.sitofp %a : vector<8xi4> to vector<8xf32>
   return %0 : vector<8xf32>
 }
 
+// CHECK-LABEL: func.func @aligned_sitofp_2d(
+func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> {
+// CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
+// CHECK:           %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
+// CHECK:           %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32>
+  %0 = arith.sitofp %a : vector<8x32xi4> to vector<8x32xf32>
+  return %0 : vector<8x32xf32>
+}
+
 // CHECK-LABEL: func.func @i4_transpose(
-//  CHECK-SAME: %[[A:[0-9a-z]*]]
 func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> {
-  // CHECK: %[[EXT:.*]] = arith.extsi %[[A]] : vector<8x16xi4> to vector<8x16xi8>
-  // CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
-  // CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi4>
+// CHECK-SAME:    %[[IN:.*]]: vector<8x16xi4>) -> vector<16x8xi4> {
+// CHECK:           %[[EXT:.*]] = vector.interleave
+// CHECK:           %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
+// CHECK:           %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi4>
   %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
   return %0 : vector<16x8xi4>
 }
 
 // CHECK-LABEL: func.func @i7_transpose(
-//  CHECK-SAME: %[[A:[0-9a-z]*]]
 func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> {
-  // CHECK: %[[EXT:.*]] = arith.extsi %[[A]] : vector<8x16xi7> to vector<8x16xi8>
-  // CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
-  // CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi7>
+// CHECK-SAME:    %[[IN:.*]]: vector<8x16xi7>) -> vector<16x8xi7> {
+// CHECK:           %[[EXT:.*]] = arith.extsi %[[IN]] : vector<8x16xi7> to vector<8x16xi8>
+// CHECK:           %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
+// CHECK:           %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi7>
   %0 = vector.transpose %a, [1, 0] : vector<8x16xi7> to vector<16x8xi7>
   return %0 : vector<16x8xi7>
 }

Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks better, thanks!

@@ -879,8 +881,7 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
interleaveMaskValues.push_back(i + (vecDimSize / 2));
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can the interleaveMaskValues (and the loop generate it) be removed now? It looks unused.

Also // 3. Interleave low and high i8 elements using a shuffle. is now outdated :)

preconditionType.isScalable())
return rewriter.notifyMatchFailure(op, "scalable or >1-D vector");
if (!preconditionType || preconditionType.isScalable())
return rewriter.notifyMatchFailure(op, "scalable vector");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there anything stopping allowing this for scalable vectors too now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would say broader validation to make sure the backend is ok with the i4 pieces but I currently don't have a large workload that compiles using scalable vectors.

Copy link
Member

@MacDue MacDue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@dcaballe dcaballe merged commit 386aa7b into llvm:main Feb 22, 2024
@dcaballe dcaballe deleted the narrow-type-interleave branch February 22, 2024 06:52
dcaballe added a commit to iree-org/llvm-project that referenced this pull request Feb 23, 2024
…ector narrow type emulation (llvm#82550)

This PR replaces the generation of `vector.shuffle` with
`vector.interleave` in the i4 conversions in vector narrow type
emulation. The multi dimensional semantics of `vector.interleave` allow
us to enable these conversion emulations also for multi dimensional
vectors.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants