[mlir] Rewrites for I2 to I8 signed and unsigned extension #121298

Merged: 18 commits into llvm:main, Jan 15, 2025

Conversation

@ziereis (Contributor) commented Dec 29, 2024

Adds rewrites for i2 to i8 signed and unsigned extension, similar to the ones that already exist for i4 to i8 conversion.

I use this for i6 quantized models, and this gives me roughly a 2x speedup for an i6 4096x4096 dequantization-matmul on an AMD 5950x.

I didn't add the rewrite for i8 to i2 truncation because I currently don't use it, but if this is needed, I can add it as well.
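For illustration, this is the kind of op the new patterns match, mirroring the tests added below; the function name here is just for the example:

```mlir
func.func @extsi_i2_example(%a: vector<8xi2>) -> vector<8xi8> {
  // Instead of a per-element extension, the rewrite emits a vector.bitcast
  // to vector<2xi8>, shift pairs that extract each 2-bit lane, and three
  // vector.interleave ops to restore the original element order.
  %0 = arith.extsi %a : vector<8xi2> to vector<8xi8>
  return %0 : vector<8xi8>
}
```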

@llvmbot (Member) commented Dec 29, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: None (ziereis)


Patch is 20.48 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/121298.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+131-6)
  • (modified) mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir (+145-4)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 181c394edc1d20..323da627de7bc2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1084,8 +1084,8 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
   unsigned srcElemBitwidth = srcType.getElementTypeBitWidth();
   unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
 
-  // Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now.
-  if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
+  // Only {s}i4/i2 -> (size_of({{s}i/f}) >= 8) are supported for now.
+  if ((srcElemBitwidth != 4 && srcElemBitwidth != 2) || dstElemBitwidth < 8 ||
       (dstElemBitwidth % srcElemBitwidth) != 0)
     return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
 
@@ -1233,6 +1233,117 @@ static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
   return rewriter.create<vector::InterleaveOp>(loc, low, high);
 }
 
+/// Rewrite the i2 -> i8 signed extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI2ToI8SignedExt(PatternRewriter &rewriter, Location loc,
+                                    Value srcValue) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  assert(srcVecType.getElementType().isSignlessInteger(2) &&
+         "Expected i2 type");
+
+  // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
+  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+  constexpr int64_t i2Toi8BitwidthFactor = 4;
+  i8VecShape.back() = i8VecShape.back() / i2Toi8BitwidthFactor;
+  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+  // Element 0 (bits 0-1)
+  constexpr int8_t shiftConst6 = 6;
+  auto shiftAttr6 = DenseElementsAttr::get(i8VecType, shiftConst6);
+  auto shiftValues6 = rewriter.create<arith::ConstantOp>(loc, shiftAttr6);
+  Value shl0 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues6);
+  Value elem0 = rewriter.create<arith::ShRSIOp>(loc, shl0, shiftValues6);
+
+  // Element 1 (bits 2-3)
+  constexpr int8_t shiftConst4 = 4;
+  auto shiftAttr4 = DenseElementsAttr::get(i8VecType, shiftConst4);
+  auto shiftValues4 = rewriter.create<arith::ConstantOp>(loc, shiftAttr4);
+  Value shl1 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues4);
+  Value elem1 = rewriter.create<arith::ShRSIOp>(loc, shl1, shiftValues6);
+
+  // Element 2 (bits 4-5)
+  constexpr int8_t shiftConst2 = 2;
+  auto shiftAttr2 = DenseElementsAttr::get(i8VecType, shiftConst2);
+  auto shiftValues2 = rewriter.create<arith::ConstantOp>(loc, shiftAttr2);
+  Value shl2 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues2);
+  Value elem2 = rewriter.create<arith::ShRSIOp>(loc, shl2, shiftValues6);
+
+  // Element 3 (bits 6-7)
+  Value elem3 = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues6);
+
+  // Interleave all 4 elements: first the even-indexed ones, then the odd-indexed ones.
+  // elem0 = [0,0,0,0]
+  // elem1 = [1,1,1,1]
+  // elem2 = [2,2,2,2]
+  // elem3 = [3,3,3,3]
+  // 02    = [0,2,0,2]
+  // 13    = [1,3,1,3]
+  // 0213  = [0,1,2,3]
+  Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, elem0, elem2);
+  Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, elem1, elem3);
+  return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
+}
+
+/// Rewrite the i2 -> i8 unsigned extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI2ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
+                                      Value srcValue) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  assert(srcVecType.getElementType().isSignlessInteger(2) &&
+         "Expected i2 type");
+
+  // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
+  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+  constexpr int64_t i2Toi8BitwidthFactor = 4;
+  i8VecShape.back() = i8VecShape.back() / i2Toi8BitwidthFactor;
+  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+  // 2. Extract each i2 element using shifts and masks.
+  constexpr uint8_t mask = 3; // Mask for 2 bits: [0000 0011]
+  auto maskAttr = DenseElementsAttr::get(i8VecType, mask);
+  auto maskValues = rewriter.create<arith::ConstantOp>(loc, maskAttr);
+
+  // Element 0 (bits 0-1)
+  Value elem0 = rewriter.create<arith::AndIOp>(loc, i8Vector, maskValues);
+
+  // Element 1 (bits 2-3)
+  constexpr int8_t shift1 = 2;
+  auto shiftAttr1 = DenseElementsAttr::get(i8VecType, shift1);
+  auto shiftValues1 = rewriter.create<arith::ConstantOp>(loc, shiftAttr1);
+  Value shifted1 = rewriter.create<arith::ShRUIOp>(loc, i8Vector, shiftValues1);
+  Value elem1 = rewriter.create<arith::AndIOp>(loc, shifted1, maskValues);
+
+  // Element 2 (bits 4-5)
+  constexpr int8_t shift2 = 4;
+  auto shiftAttr2 = DenseElementsAttr::get(i8VecType, shift2);
+  auto shiftValues2 = rewriter.create<arith::ConstantOp>(loc, shiftAttr2);
+  Value shifted2 = rewriter.create<arith::ShRUIOp>(loc, i8Vector, shiftValues2);
+  Value elem2 = rewriter.create<arith::AndIOp>(loc, shifted2, maskValues);
+
+  // Element 3 (bits 6-7)
+  constexpr int8_t shift3 = 6;
+  auto shiftAttr3 = DenseElementsAttr::get(i8VecType, shift3);
+  auto shiftValues3 = rewriter.create<arith::ConstantOp>(loc, shiftAttr3);
+  Value shifted3 = rewriter.create<arith::ShRUIOp>(loc, i8Vector, shiftValues3);
+  Value elem3 = rewriter.create<arith::AndIOp>(loc, shifted3, maskValues);
+
+  // Interleave all 4 elements: first the even-indexed ones, then the odd-indexed ones.
+  // elem0 = [0,0,0,0]
+  // elem1 = [1,1,1,1]
+  // elem2 = [2,2,2,2]
+  // elem3 = [3,3,3,3]
+  // 02    = [0,2,0,2]
+  // 13    = [1,3,1,3]
+  // 0213  = [0,1,2,3]
+  Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, elem0, elem2);
+  Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, elem1, elem3);
+  return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
+}
+
 /// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
 /// ops that take advantage of high-level information to avoid leaving LLVM to
 /// scramble with peephole optimizations.
@@ -1438,11 +1549,21 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
     // Perform the rewrite.
     Value subByteExt;
     if (isSigned) {
-      subByteExt =
-          rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+      if (srcVecType.getElementType().getIntOrFloatBitWidth() == 2)
+        subByteExt =
+            rewriteI2ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+      else {
+        subByteExt =
+            rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+      }
     } else {
-      subByteExt =
-          rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+      if (srcVecType.getElementType().getIntOrFloatBitWidth() == 2) {
+        subByteExt =
+            rewriteI2ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+      } else {
+        subByteExt =
+            rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+      }
     }
 
     // Finalize the rewrite.
@@ -1489,6 +1610,10 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
                                              truncOp)))
       return failure();
 
+    // not supported currently.
+    if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)
+      return failure();
+
     // Create a new iX -> i8 truncation op.
     Location loc = truncOp.getLoc();
     auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 210025e30d7db5..0b469066f290c2 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -220,8 +220,8 @@ func.func @aligned_extsi_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
   return %0 : vector<8xi32>
 }
 
-// CHECK-LABEL: func.func @aligned_extsi_2d(
-func.func @aligned_extsi_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-LABEL: func.func @aligned_extsi_i4_2d(
+func.func @aligned_extsi_i4_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
 // CHECK-SAME:    %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
 // CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
 // CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
@@ -234,6 +234,72 @@ func.func @aligned_extsi_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
   return %0 : vector<8x32xi32>
 }
 
+// CHECK-LABEL: func.func @aligned_extsi_i2_to_i8(
+func.func @aligned_extsi_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> {
+// CHECK-SAME:      %[[IN:.*]]: vector<8xi2>) -> vector<8xi8> {
+// CHECK:           %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK:           %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK:           %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// CHECK:           %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK:           %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK:           %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK:           %[[RESULT:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+  %0 = arith.extsi %a : vector<8xi2> to vector<8xi8>
+  return %0 : vector<8xi8>
+}
+
+
+// CHECK-LABEL: func.func @aligned_extsi_i2_to_i32(
+func.func @aligned_extsi_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> {
+// CHECK-SAME:      %[[IN:.*]]: vector<8xi2>) -> vector<8xi32> {
+// CHECK:           %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK:           %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK:           %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// CHECK:           %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK:           %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK:           %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+// CHECK:           %[[RESULT:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
+  %0 = arith.extsi %a : vector<8xi2> to vector<8xi32>
+  return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extsi_i2_2d(
+func.func @aligned_extsi_i2_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK-SAME:      %[[IN:.*]]: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK:           %[[CST_2:.*]] = arith.constant dense<2> : vector<8x8xi8>
+// CHECK:           %[[CST_4:.*]] = arith.constant dense<4> : vector<8x8xi8>
+// CHECK:           %[[CST_6:.*]] = arith.constant dense<6> : vector<8x8xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi2> to vector<8x8xi8>
+// CHECK:           %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
+// CHECK:           %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<8x8xi8>
+// CHECK:           %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<8x8xi8>
+// CHECK:           %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<8x8xi8>
+// CHECK:           %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<8x8xi8>
+// CHECK:           %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<8x8xi8>
+// CHECK:           %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
+// CHECK:           %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<8x8xi8>
+// CHECK:           %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<8x8xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<8x16xi8>
+// CHECK:           %[[RESULT:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+  %0 = arith.extsi %a : vector<8x32xi2> to vector<8x32xi32>
+  return %0 : vector<8x32xi32>
+}
+
 
 // CHECK-LABEL: func.func @aligned_trunci_i8_to_i4(
 func.func @aligned_trunci_i8_to_i4(%a: vector<8xi8>) -> vector<8xi4> {
@@ -292,6 +358,13 @@ func.func @aligned_trunci_nd(%a: vector<3x8x32xi32>) -> vector<3x8x32xi4> {
   return %0 : vector<3x8x32xi4>
 }
 
+func.func @aligned_trunci_i8_to_i2_no_match(%a: vector<8xi8>) -> vector<8xi2> {
+  // CHECK-NOT: arith.bitcast
+  // CHECK: arith.trunci %[[IN:.*]] : vector<8xi8> to vector<8xi2>
+  %0 = arith.trunci %a : vector<8xi8> to vector<8xi2>
+  return %0 : vector<8xi2>
+}
+
 // CHECK-LABEL: func.func @aligned_extui_i4_to_i8(
 func.func @aligned_extui_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> {
 // CHECK-SAME:                             %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
@@ -319,8 +392,8 @@ func.func @aligned_extui_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
   return %0 : vector<8xi32>
 }
 
-// CHECK-LABEL: func.func @aligned_extui_2d(
-func.func @aligned_extui_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-LABEL: func.func @aligned_extui_i4_2d(
+func.func @aligned_extui_i4_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
 // CHECK-SAME:                                %[[VAL_0:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
 // CHECK:           %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
 // CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8>
@@ -333,6 +406,74 @@ func.func @aligned_extui_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
   return %0 : vector<8x32xi32>
 }
 
+// CHECK-LABEL: func.func @aligned_extui_i2_to_i8(
+func.func @aligned_extui_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> {
+// CHECK-SAME:      %[[IN:.*]]: vector<8xi2>) -> vector<8xi8> {
+// CHECK:           %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK:           %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK:           %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<2xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// CHECK:           %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK:           %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK:           %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK:           %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK:           %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK:           %[[SHR_6:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[ELEM3:.*]] = arith.andi %[[SHR_6]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK:           %[[RESULT:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+  %0 = arith.extui %a : vector<8xi2> to vector<8xi8>
+  return %0 : vector<8xi8>
+}
+
+// CHECK-LABEL: func.func @aligned_extui_i2_to_i32(
+func.func @aligned_extui_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> {
+// CHECK-SAME:      %[[IN:.*]]: vector<8xi2>) -> vector<8xi32> {
+// CHECK:           %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK:           %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK:           %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<2xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// CHECK:           %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK:           %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK:           %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK:           %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK:           %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK:           %[[SHR_6:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK:           %[[ELEM3:.*]] = arith.andi %[[SHR_6]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK:           %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+// CHECK:           %[[RESULT:.*]] = arith.extui %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
+  %0 = arith.extui %a : vector<8xi2> to vector<8xi32>
+  return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extui_i2_2d(
+func.func @aligned_extui_i2_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK-SAME:      %[[IN:.*]]: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK:           %[[CST_6:.*]] = arith.constant dense<6> : vector<8x8xi8>
+// CHECK:           %[[CST_4:.*]] = arith.constant dense<4> : vector<8x8xi8>
+// CHECK:           %[[CST_2:.*]] = arith.constant dense<2> : vector<8x8xi8>
+// CHECK:           %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<8x8xi8>
+// CHECK:           %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi2> to vector<8x8xi8>
+// CHECK:           %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// CHECK:           %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<8x8xi8>
+// CHECK:           %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// CHECK:           %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<8x8xi8>
+// CHECK:           %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// CHECK:           %[[SHR_6:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
+// CHECK:           %[[ELEM3:.*]] = arith.andi %[[SHR_6]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// CHECK:           %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<8x8xi8>
+// CHECK:           %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<8x8xi...
[truncated]

@banach-space (Contributor) left a comment

Thanks for working on this!

I didn't add the rewrite for i8 to i2 truncation because I currently don't use it, but if this is needed, I can add it as well.

A TODO + negative test would be sufficient for me.

@@ -1489,6 +1610,10 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
truncOp)))
return failure();

// not supported currently.

[nit] I'd move this near commonConversionPrecondition - this check is basically "complementing" that hook.

Suggested change
// not supported currently.
// TODO: Add support for truncating to i2.

Also, I'd add a negative test.

@ziereis (Author):

Done. A negative test should now exist as well: I just check that an i8 to i2 truncation doesn't match; is that not sufficient?

@ziereis (Author) commented Jan 6, 2025

Thanks a lot for the review. I believe I have addressed all the comments. I also noticed a superfluous `AndIOp` in the unsigned case, so I removed it and adjusted the tests accordingly.
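For context, a minimal sketch of the redundancy (hypothetical function, not the exact IR from the patch): after the unsigned shift by 6, the extracted bits already sit in the low positions with zeros above them, so the trailing mask is a no-op.

```mlir
func.func @elem3_unsigned(%bytes: vector<2xi8>) -> vector<2xi8> {
  %c6 = arith.constant dense<6> : vector<2xi8>
  %c3 = arith.constant dense<3> : vector<2xi8>
  // shrui by 6 leaves only the top two bits, already zero-extended into [0, 3].
  %shr = arith.shrui %bytes, %c6 : vector<2xi8>
  // This mask therefore changes nothing and can be dropped.
  %e3 = arith.andi %shr, %c3 : vector<2xi8>
  return %e3 : vector<2xi8>
}
```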

@hanhanW (Contributor) commented Jan 7, 2025

cc @lialan, who has recently been working in this area.

@hanhanW (Contributor) left a comment

The logic of signed and unsigned conversion is very similar. The only difference is that they use different extractNBits* functions. I think we can pass the function to the method, so we no longer need to duplicate the logic. E.g.,

using ExtractNBitsFn = std::function<Value(PatternRewriter&, Location, Value, int, int)>;

static Value rewriteI4ToI8(PatternRewriter &rewriter, Location loc,
                                    Value srcValue, ExtractNBitsFn extFn) {
// ...
  Value low = extFn(rewriter, loc, i8Vector, 0, 4);
  Value high = extFn(rewriter, loc, i8Vector, 4, 4);
// ...
}

@ziereis (Author) commented Jan 7, 2025

Thanks again for the review. All comments should be addressed.

@ziereis requested a review from hanhanW on January 7, 2025, 18:30
@nikic changed the title from "Rewrites for I2 to I8 signed and unsigned extension" to "[mlir] Rewrites for I2 to I8 signed and unsigned extension" on Jan 7, 2025
@banach-space (Contributor) left a comment

Thanks for all the updates so far!

Comment on lines 1197 to 1206
/// Extracts a signed N-bit sequence from each element of an 8-bit vector,
/// starting at the specified bit index.

I am a bit confused about the naming and this description.

IIUC, this method will:

  • for every byte b in src (which is a vector of bytes),
  • extract numBits starting at bitIdx (let's call the result inputVal), and
  • return a byte matching the value encoded in inputVal.

So this method is more like extractNBitsAndReturnAsByte?

@ziereis (Author), Jan 9, 2025:

  • for every byte b in src (which is a vector of bytes),
  • extract numBits starting at bitIdx (let's call the result inputVal), and
  • return a byte matching the value encoded in inputVal.

It will extract numBits from every byte of src at bitIdx and return a vector of bytes; the resultType will always be the same as the srcType.
So, for example, if numBits is 4, it will treat the inputVal as an i4 and (sign-)extend it to an i8 value.

I'm not sure about the name either, maybe extractNBitsAnd(Sign)ExtToI8?


extractNbitsPerByteAndExtendToI8?

Am I correct that this method assumes that the src and dst element type is i8?

Yes.

Comment on lines 1233 to 1234
/// shr = src >> 6 = [00010110]
/// result = shr & mask = [00000010]

Wouldn't arith::shl -> arith::shrui work here just fine?

@ziereis (Author), Jan 9, 2025:

It would, but you would end up with more instructions in total: to extract the bits starting at position 0 you would still have to do the shl + shr, whereas with the mask approach a single andi suffices.

Example for 4 bit:

Shifts:
src : i8
shl = arith::shl(src)
low = arith::shrui(shl)
high = arith::shrui(src)

Mask:
src : i8
low = mask(src, 15)
high = arith::shrui(src)
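In MLIR terms, the two variants for the low nibble look roughly like this (a sketch with a hypothetical function, assuming a vector<2xi8> source):

```mlir
func.func @low_nibble(%src: vector<2xi8>) -> (vector<2xi8>, vector<2xi8>) {
  %c4 = arith.constant dense<4> : vector<2xi8>
  %c15 = arith.constant dense<15> : vector<2xi8>
  // Shift variant: two instructions for the low nibble.
  %shl = arith.shli %src, %c4 : vector<2xi8>
  %low_shift = arith.shrui %shl, %c4 : vector<2xi8>
  // Mask variant: a single instruction for the low nibble.
  %low_mask = arith.andi %src, %c15 : vector<2xi8>
  return %low_shift, %low_mask : vector<2xi8>, vector<2xi8>
}
```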


In an ideal world, we'd keep extractNBitsFromVectorSigned and extractNBitsFromVectorUnsigned very consistent, so I'm trying to understand the reason for the divergence.

Example for 4 bit:

Shifts: src : i8
shl = arith::shl(src)
low = arith::shrui(shl)
high = arith::shrui(src)

Let me make sure I got this right. I assume that we are extracting 2 bits at index 0 (indexing from MSB, i.e. left-most bit):

src: [10 10 11 00]
shl: arith.shl(src) --> [10 10 11 00]
shr: arith.shrui(shl) --> [00 00 00 10]

Why do we need another shr in your example? (i.e., what's low and high?)

@ziereis (Author), Jan 9, 2025:

Edit: I'm sorry, I just looked at the example again and there is a mistake: it should be src >> 2, not 6. So that's probably where the confusion comes from.

Okay, so my understanding is that when talking about bits, the convention is to read them right to left, so the LSB is at index 0. The "low bits" are then the rightmost bits and the "high bits" the leftmost. I will add this to the documentation.

src:      [ 1 | 0 | 1 | 0 | 1 | 1 | 0 | 0 ]
indices:  [ 7 | 6 | 5 | 4 | 3 | 2 | 1 | 0 ]

So I can always extract N bits starting from index 0 with a single arith::AndIOp + mask.
If I used the shl + shrui pattern, I would need 2 instructions.

So this is basically a small shortcut I can take with unsigned numbers; for the signed case I always have to use shrsi to preserve the sign.
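A concrete sign-extension sketch (hypothetical function, not from the patch): for a byte 0b10101110, element 0 is the i2 value 0b10 = -2, and reproducing it as an i8 requires the arithmetic shift.

```mlir
func.func @elem0_signed(%byte: vector<1xi8>) -> vector<1xi8> {
  %c6 = arith.constant dense<6> : vector<1xi8>
  // For 0b10101110, shifting left by 6 gives 0b10000000 ...
  %shl = arith.shli %byte, %c6 : vector<1xi8>
  // ... and shrsi replicates the sign bit: 0b11111110 = -2,
  // i.e. the i2 value 0b10 sign-extended to i8.
  %e0 = arith.shrsi %shl, %c6 : vector<1xi8>
  return %e0 : vector<1xi8>
}
```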


Okay, so my understanding is that when talking about bits, the convention is to read them right to left, so the LSB is at index 0. The "low bits" are then the rightmost bits and the "high bits" the leftmost.

Yes, that makes more sense 😅 But then, when storing a vector of 8 elements in memory, it's usually like: v[0], v[1], v[2], .... Not to mention the split into little-endian/big-endian (but we are not storing anything in memory here, so not relevant). My point being - this is worth documenting.

I will add this to the documentation

That was going to be my next request 😅

So this is basically a small shortcut I can take with unsigned numbers; for the signed case I always have to use shrsi to preserve the sign.

TBH, this is where I'd really expect LLVM to do its job (as in, I wouldn't consider optimising this much in MLIR). But your measurements might suggest otherwise?

To me, it would make a lot of sense if you merged extractNBitsFromVectorSigned and extractNBitsFromVectorUnsigned; the only difference would be the right shift. Consistency is great :)

That said, I don't feel that strongly about it and I admit that it's an optimisation. If you decide to keep the arith::ShRUIOp + arith::AndIOp approach, then could you add a comment in extractNBitsFromVectorUnsigned to document the reason for the divergence? Something along the lines of:

NOTE: Similarly to extractNBitsFromVectorSigned, this could be implemented with arith::ShLIOp + arith::ShRUIOp. However, by using arith::ShRUIOp + arith::AndIOp we eliminate the shift left when the bit index is 0.

I know that this feels obvious today, but I always try to optimise for my future self who will forget :)


I have not measured whether it would make any difference; I will look into it. For now I have just added the note. However, changing this would also mean changing all the tests for the i4 to i8 rewrites.

github-actions bot commented Jan 9, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@ziereis force-pushed the i2toi8-ext-rewrites branch from d543d5a to b975051 on January 9, 2025, 15:24
@banach-space (Contributor) left a comment

I love the updated comments 🙏🏻

@banach-space (Contributor) left a comment

Thanks for all the updates, this is looking really neat 🙏🏻 A few final comments, nothing major.

int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
assert(8 % srcBitwidth == 0 &&
"Unsupported sub-byte type (not a divisor of i8)");
int64_t bitwidthFactor = 8 / srcBitwidth;

To keep the naming consistent within the file, could this be numSrcElemsPerByte?

Similar variable elsewhere:

const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;


Done. There are some other places that use the same name, but I'm not sure whether they refer to the same thing.


const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;
const int numSrcElemsPerDestElem = 8 / srcElemBitwidth;

Why hard-code this to 8?

Also, comment for L1105 (sadly GitHub doesn't allow comments for things outside the diff) - please update the error msg.


I changed this check since the dst in this case will always be an i8 vector.

For example:

%0 = arith.extui %a : vector<4xi2> to vector<4xi32>

should be possible, because it's only important that I can pack the 4 i2 values into a 1xi8 vector (otherwise I would have an invalid bitcast).

But I also added a test that checks that something like this fails:

%0 = arith.extui %a : vector<2xi2> to vector<2xi32>
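To make the packing constraint concrete, a sketch (hypothetical function name): the aligned case only needs the trailing i2 elements to fill whole bytes.

```mlir
func.func @pack_i2(%a: vector<4xi2>) -> vector<1xi8> {
  // Legal: 4 x i2 = 8 bits packs into exactly one byte.
  %0 = vector.bitcast %a : vector<4xi2> to vector<1xi8>
  return %0 : vector<1xi8>
}
// A vector<2xi2> source provides only 4 bits, so no i8 bitcast exists and
// the vector<2xi2> -> vector<2xi32> extension must not match.
```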

Comment on lines 197 to 210
// CHECK-LABEL: func.func @aligned_i4_trailing_dim_not_multiple(
func.func @aligned_i4_trailing_dim_not_multiple(%a: vector<1xi4>) -> vector<1xi8> {
// CHECK-NOT: arith.bitcast
// CHECK: arith.extsi %[[IN:.*]] : vector<1xi4> to vector<1xi8>
%0 = arith.extsi %a : vector<1xi4> to vector<1xi8>
return %0 : vector<1xi8>
}

// CHECK-LABEL: func.func @aligned_i2_trailing_dim_not_multiple(
func.func @aligned_i2_trailing_dim_not_multiple(%a: vector<2xi2>) -> vector<2xi8> {
// CHECK-NOT: arith.bitcast
// CHECK: arith.extsi %[[IN:.*]] : vector<2xi2> to vector<2xi8>
%0 = arith.extsi %a : vector<2xi2> to vector<2xi8>
return %0 : vector<2xi8>
Ping


This was meant more as a negative test for the alignedConversionPrecondition. Also, I think the file you linked refers to vector::populateVectorNarrowTypeEmulationPatterns, whereas I changed vector::populateVectorNarrowTypeRewritePatterns (in the same file), if I understood correctly?


Right, now I see. Then, still, aligned --> unaligned ;-) Also, to match the naming convention used elsewhere:

@aligned_i4_trailing_dim_not_multiple --> @unaligned_extsi_i4_to_i8_trailing_dim_not_multiple

In fact, I'd do this:

// Negative test - the trailing dim 2 is not a multiple of 4 (i.e. 8 / 2).

func.func @unaligned_extsi_i4_to_i8

Similar comment for the other negative test.

@ziereis (Author) commented Jan 10, 2025

Thanks again for the comments.

@hanhanW (Contributor) left a comment

I like the updated documentation, thanks a lot! I just have a final optional nit.

@ziereis (Author) commented Jan 13, 2025

@banach-space Sorry, I didn’t mean to ignore the comments. I somehow started my own review on the PR or something, so all the comments were only "pending." I didn’t know what this meant, but it appears they were only visible to me. 😅 I hope they’re visible now!

@banach-space (Contributor):

@banach-space Sorry, I didn’t mean to ignore the comments.

No worries, GitHub is pretty bad for tracking threads of conversation. Hence my "ping" - I suspected that some things were displaying for only one of us ;-)

@banach-space (Contributor) left a comment

LGTM % some minor requests.

This is quite intricate, and I really appreciate the effort you've put into implementing this and so patiently addressing all my comments.


@ziereis (Author) commented Jan 14, 2025

The comments should be resolved now. Thanks as well for all the reviews and being patient with me :)

@banach-space (Contributor):

Thanks as well for all the reviews and being patient with me :)

Works both ways :) Do you have commit access?

@ziereis (Author) commented Jan 14, 2025

No, I would be happy if you could merge it.

@banach-space merged commit 929eb50 into llvm:main on Jan 15, 2025
8 checks passed
@banach-space (Contributor):

no, i would be happy if you could merge it.

Landed 🥳 Thanks for working on this 🙏🏻

Note, GitHub added me as a co-author. AFAIK, that's the default behaviour when squashing commits like this (IIUC, you've incorporated my suggestion through GitHub's UI). I don't really mind, but obviously you've done 99% of the work here 😅

banach-space added a commit to banach-space/llvm-project that referenced this pull request Jan 19, 2025
This is PR 4 in a series of N patches aimed at improving
"VectorEmulateNarrowType.cpp". This is mainly minor refactoring, no
major functional changes are made/added.

1. Update `alignedConversionPrecondition` (1):

This method didn't require the vector type for the "destination"
argument. The underlying element type is sufficient. The corresponding
argument has been renamed as `multiByteScalarTy` - this is meant as the
multi-byte emulated type (`i8`, `i16`, `i32`, etc).

2. Update `alignedConversionPrecondition` (2):

In llvm#121298, we replaced `dstElemBitwidth` in this calculation:

```cpp
  const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;
```

with the hard-coded value of 8:
```cpp
  const int numSrcElemsPerDestElem = 8 / srcElemBitwidth;
```

That was correct for the patterns for which this hook was/is used:

  * `RewriteAlignedSubByteIntExt`,
  * `RewriteAlignedSubByteIntTrunc`.

The destination type (or, more precisely, the emulated type) was always
`i8`.

In this PR, I am switching back to a more generic approach - the
calculation should take into account the bit-width of the emulated type.

Note that at the call sites I am passing `i8` as the emulated type, so the
end-result is effectively identical. However, the intent is clearer, i.e.,
the underlying value is 8 because the emulated type happens to be `i8`
(as opposed to using a magic number).

3. Update alignedConversionPrecondition (3):

The final check has been replaced with a new helper method,
`isSubByteVecFittable`. This new method is also re-used within the code
and hopefully will allow us more code re-use moving forward (to avoid
re-implementing the same condition).

4. Update alignedConversionPrecondition (4):

NEXT STEPS:

We need to clarify the meaning of "source" and "destination" types.
Currently the usage is ambiguous.

For example, for this `arith.extsi` Op, `vector<8xi2>` and
`vector<8xi32>` are the "source" and "destination" types, respectively:

```mlir
  %0 = arith.extsi %arg0 : vector<8xi2> to vector<8xi32>
```

However, patterns like `RewriteAlignedSubByteIntExt` introduce
`vector.bitcast` Ops like this:

```mlir
  %bitcast = vector.bitcast %arg0 : vector<8xi2> to vector<2xi8>
```

I've noticed that we tend to mix `vector<2xi8>` and `vector<8xi32>` as
the destination types and that should be clarified.
banach-space added a commit to banach-space/llvm-project that referenced this pull request Jan 20, 2025
This  PR aims at improving "VectorEmulateNarrowType.cpp". This is mainly
minor refactoring, no major functional changes are made/added.
Implements llvm#123630.

**CHANGE 1**
Renames `srcBits/dstBits` to `oldBits/newBits` to improve consistency in
naming within the file. This is illustrated below:

```cpp
  // Extracted from VectorEmulateNarrowType.cpp
  Type oldElementType = op.getType().getElementType();
  Type newElementType = convertedType.getElementType();

  // BEFORE (mixing old/new and src/dst):
  // int srcBits = oldElementType.getIntOrFloatBitWidth();
  // int dstBits = newElementType.getIntOrFloatBitWidth();

  // AFTER (consistently using old/new):
  int oldBits = oldElementType.getIntOrFloatBitWidth();
  int newBits = newElementType.getIntOrFloatBitWidth();
```

Also adds some comments and unifies related "rewriter notification"
messages.

**CHANGE 2**
Renames the variable "scale". Note, "scale" could mean either:

  * "original-elements-per-emulated-type", or
  * "emulated-elements-per-original-type".

While from the context it is clear that it's always the former (original
type is always a sub-byte type and the emulated type is usually `i8`),
this PR reduces the cognitive load by making this clear.

**CHANGE 3**
Replaces `isUnalignedEmulation` with `isFullyAligned`

Note, `isUnalignedEmulation` is always computed following a
"per-element-alignment" condition:
```cpp
// Check per-element alignment.
if (newBits % oldBits != 0) {
  return rewriter.notifyMatchFailure(op, "unalagined element types");
}

// (...)

bool isUnalignedEmulation = origElements % elementsPerContainerType != 0;
```

Given that `isUnalignedEmulation` captures only one of two conditions
required for "full alignment", it should be re-named as
`isPartiallyUnalignedEmulation`. Instead, I've flipped the condition and
renamed it as `isFullyAligned`:

```cpp
bool isFullyAligned = origElements % elementsPerContainerType == 0;
```

**CHANGE 4**
Unifies various comments throughout the file (for consistency).

**CHANGE 5**
Adds new comments throughout the file and adds TODOs where high-level
comments are missing.

**CHANGE 6**
Update `alignedConversionPrecondition` (1):

This method didn't require the vector type for the "destination"
argument. The underlying element type is sufficient. The corresponding
argument has been renamed as `multiByteScalarTy` - this is meant as the
multi-byte emulated type (`i8`, `i16`, `i32`, etc).

**CHANGE 7**
Update `alignedConversionPrecondition` (2):

In llvm#121298, we replaced `dstElemBitwidth` in this calculation:

```cpp
  const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;
```

with the hard-coded value of 8:
```cpp
  const int numSrcElemsPerDestElem = 8 / srcElemBitwidth;
```

That was correct for the patterns for which this hook was/is used:

  * `RewriteAlignedSubByteIntExt`,
  * `RewriteAlignedSubByteIntTrunc`.

The destination type (or, more precisely, the emulated type) was always
`i8`.

In this PR, I am switching back to a more generic approach - the
calculation should take into account the bit-width of the emulated type.

Note that at the call sites I am passing `i8` as the emulated type, so the
end-result is effectively identical. However, the intent is clearer, i.e.,
the underlying value is 8 because the emulated type happens to be `i8`
(as opposed to using a magic number).

**CHANGE 8**
Update alignedConversionPrecondition (3):

The final check has been replaced with a new helper method,
`isSubByteVecFittable`. This new method is also re-used within the code
and hopefully will allow us more code re-use moving forward (to avoid
re-implementing the same condition).
banach-space added a commit to banach-space/llvm-project that referenced this pull request Feb 6, 2025
banach-space added a commit to banach-space/llvm-project that referenced this pull request Feb 15, 2025
banach-space added a commit to banach-space/llvm-project that referenced this pull request Feb 25, 2025
banach-space added a commit to banach-space/llvm-project that referenced this pull request Feb 26, 2025
banach-space added a commit to banach-space/llvm-project that referenced this pull request Mar 1, 2025
This is PR 4 in a series of N patches aimed at improving
"VectorEmulateNarrowType.cpp". This is mainly minor refactoring, no
major functional changes are made/added.

1. Update `alignedConversionPrecondition` (1):

This method didn't require the vector type for the "destination"
argument. The underlying element type is sufficient. The corresponding
argument has been renamed as `multiByteScalarTy` - this is meant as the
multi-byte emulated type (`i8`, `i16`, `i32`, etc).

2. Update `alignedConversionPrecondition` (2):

In llvm#121298, we replaced `dstElemBitwidth` in this calculation:

```cpp
  const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;
```

with the hard-coded value of 8:
```cpp
  const int numSrcElemsPerDestElem = 8 / srcElemBitwidth;
```

That was correct for the patterns for which this hook was/is used:

  * `RewriteAlignedSubByteIntExt`,
  * `RewriteAlignedSubByteIntTrunc`.

The destination type (or, more precisely, the emulated type) was always
`i8`.

In this PR, I am switching back to a more generic approach - the
calculation should take into account the bit-width of the emulated type.

Note that at the call sites I am passing `i8` as the emulated type, so the
end-result is effectively identical. However, the intent is clearer, i.e.,
the underlying value is 8 because the emulated type happens to be `i8`
(as opposed to using a magic number).

3. Update alignedConversionPrecondition (3):

The final check has been replaced with a new helper method,
`isSubByteVecFittable`. This new method is also re-used within the code
and hopefully will allow us more code re-use moving forward (to avoid
re-implementing the same condition).

NEXT STEPS (1):

We need to clarify the meaning of "source" and "destination" types.
Currently the usage is ambiguous.

For example, for this `arith.extsi` Op, `vector<8xi2>` and
`vector<8xi32>` are the "source" and "destination" types, respectively:

```mlir
  %0 = arith.extsi %arg0 : vector<8xi2> to vector<8xi32>
```

However, patterns like `RewriteAlignedSubByteIntExt` introduce
`vector.bitcast` Ops like this:

```mlir
  %bitcast = vector.bitcast %arg0 : vector<8xi2> to vector<2xi8>
```

I've noticed that we tend to mix `vector<2xi8>` and `vector<8xi32>` as
the destination types and that should be clarified.

NEXT STEPS (2):

With this PR, I am introducing explicit references to "sub-byte" as
that is effectively what this logic is used for (i.e. for emulating
"sub-byte" types). We should either generalise (which would include
increasing test coverage) or restrict everything to "sub-byte" type
emulation.
banach-space added a commit to banach-space/llvm-project that referenced this pull request Mar 15, 2025