[mlir] Rewrites for I2 to I8 signed and unsigned extension #121298
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector

Author: None (ziereis)

Changes

Adds rewrites for i2 to i8 signed and unsigned extension, similar to the ones that already exist for the i4 to i8 conversion. I use this for i6 quantized models, and this gives me roughly a 2x speedup for an i6 4096x4096 dequantization-matmul on an AMD 5950x.

I didn't add the rewrite for i8 to i2 truncation because I currently don't use it, but if this is needed, I can add it as well.

Patch is 20.48 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/121298.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 181c394edc1d20..323da627de7bc2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1084,8 +1084,8 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
unsigned srcElemBitwidth = srcType.getElementTypeBitWidth();
unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
- // Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now.
- if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
+ // Only {s}i4/i2 -> (size_of({{s}i/f}) >= 8) are supported for now.
+ if ((srcElemBitwidth != 4 && srcElemBitwidth != 2) || dstElemBitwidth < 8 ||
(dstElemBitwidth % srcElemBitwidth) != 0)
return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
@@ -1233,6 +1233,117 @@ static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
return rewriter.create<vector::InterleaveOp>(loc, low, high);
}
+/// Rewrite the i2 -> i8 signed extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI2ToI8SignedExt(PatternRewriter &rewriter, Location loc,
+ Value srcValue) {
+ VectorType srcVecType = cast<VectorType>(srcValue.getType());
+ assert(srcVecType.getElementType().isSignlessInteger(2) &&
+ "Expected i2 type");
+
+ // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
+ SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+ constexpr int64_t i2Toi8BitwidthFactor = 4;
+ i8VecShape.back() = i8VecShape.back() / i2Toi8BitwidthFactor;
+ auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+ Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+ // Element 0 (bits 0-1)
+ constexpr int8_t shiftConst6 = 6;
+ auto shiftAttr6 = DenseElementsAttr::get(i8VecType, shiftConst6);
+ auto shiftValues6 = rewriter.create<arith::ConstantOp>(loc, shiftAttr6);
+ Value shl0 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues6);
+ Value elem0 = rewriter.create<arith::ShRSIOp>(loc, shl0, shiftValues6);
+
+ // Element 1 (bits 2-3)
+ constexpr int8_t shiftConst4 = 4;
+ auto shiftAttr4 = DenseElementsAttr::get(i8VecType, shiftConst4);
+ auto shiftValues4 = rewriter.create<arith::ConstantOp>(loc, shiftAttr4);
+ Value shl1 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues4);
+ Value elem1 = rewriter.create<arith::ShRSIOp>(loc, shl1, shiftValues6);
+
+ // Element 2 (bits 4-5)
+ constexpr int8_t shiftConst2 = 2;
+ auto shiftAttr2 = DenseElementsAttr::get(i8VecType, shiftConst2);
+ auto shiftValues2 = rewriter.create<arith::ConstantOp>(loc, shiftAttr2);
+ Value shl2 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues2);
+ Value elem2 = rewriter.create<arith::ShRSIOp>(loc, shl2, shiftValues6);
+
+ // Element 3 (bits 6-7)
+ Value elem3 = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues6);
+
+ // interleave all 4 elements by first interleaving even elements and then odd
+ // elem0 = [0,0,0,0]
+ // elem1 = [1,1,1,1]
+ // elem2 = [2,2,2,2]
+ // elem3 = [3,3,3,3]
+ // 02 = [0,2,0,2]
+ // 13 = [1,3,1,3]
+ // 0213 = [0,1,2,3]
+ Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, elem0, elem2);
+ Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, elem1, elem3);
+ return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
+}
+
+/// Rewrite the i2 -> i8 unsigned extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI2ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
+ Value srcValue) {
+ VectorType srcVecType = cast<VectorType>(srcValue.getType());
+ assert(srcVecType.getElementType().isSignlessInteger(2) &&
+ "Expected i2 type");
+
+ // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
+ SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+ constexpr int64_t i2Toi8BitwidthFactor = 4;
+ i8VecShape.back() = i8VecShape.back() / i2Toi8BitwidthFactor;
+ auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+ Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+ // 2. Extract each i2 element using shifts and masks
+ constexpr uint8_t mask = 3; // Mask for 2 bits: [0000 0011]
+ auto maskAttr = DenseElementsAttr::get(i8VecType, mask);
+ auto maskValues = rewriter.create<arith::ConstantOp>(loc, maskAttr);
+
+ // Element 0 (bits 0-1)
+ Value elem0 = rewriter.create<arith::AndIOp>(loc, i8Vector, maskValues);
+
+ // Element 1 (bits 2-3)
+ constexpr int8_t shift1 = 2;
+ auto shiftAttr1 = DenseElementsAttr::get(i8VecType, shift1);
+ auto shiftValues1 = rewriter.create<arith::ConstantOp>(loc, shiftAttr1);
+ Value shifted1 = rewriter.create<arith::ShRUIOp>(loc, i8Vector, shiftValues1);
+ Value elem1 = rewriter.create<arith::AndIOp>(loc, shifted1, maskValues);
+
+ // Element 2 (bits 4-5)
+ constexpr int8_t shift2 = 4;
+ auto shiftAttr2 = DenseElementsAttr::get(i8VecType, shift2);
+ auto shiftValues2 = rewriter.create<arith::ConstantOp>(loc, shiftAttr2);
+ Value shifted2 = rewriter.create<arith::ShRUIOp>(loc, i8Vector, shiftValues2);
+ Value elem2 = rewriter.create<arith::AndIOp>(loc, shifted2, maskValues);
+
+ // Element 3 (bits 6-7)
+ constexpr int8_t shift3 = 6;
+ auto shiftAttr3 = DenseElementsAttr::get(i8VecType, shift3);
+ auto shiftValues3 = rewriter.create<arith::ConstantOp>(loc, shiftAttr3);
+ Value shifted3 = rewriter.create<arith::ShRUIOp>(loc, i8Vector, shiftValues3);
+ Value elem3 = rewriter.create<arith::AndIOp>(loc, shifted3, maskValues);
+
+ // interleave all 4 elements by first interleaving even elements and then odd
+ // elem0 = [0,0,0,0]
+ // elem1 = [1,1,1,1]
+ // elem2 = [2,2,2,2]
+ // elem3 = [3,3,3,3]
+ // 02 = [0,2,0,2]
+ // 13 = [1,3,1,3]
+ // 0213 = [0,1,2,3]
+ Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, elem0, elem2);
+ Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, elem1, elem3);
+ return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
+}
+
/// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
/// ops that take advantage of high-level information to avoid leaving LLVM to
/// scramble with peephole optimizations.
@@ -1438,11 +1549,21 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
// Perform the rewrite.
Value subByteExt;
if (isSigned) {
- subByteExt =
- rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+ if (srcVecType.getElementType().getIntOrFloatBitWidth() == 2)
+ subByteExt =
+ rewriteI2ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+ else {
+ subByteExt =
+ rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+ }
} else {
- subByteExt =
- rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+ if (srcVecType.getElementType().getIntOrFloatBitWidth() == 2) {
+ subByteExt =
+ rewriteI2ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+ } else {
+ subByteExt =
+ rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+ }
}
// Finalize the rewrite.
@@ -1489,6 +1610,10 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
truncOp)))
return failure();
+ // not supported currently.
+ if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)
+ return failure();
+
// Create a new iX -> i8 truncation op.
Location loc = truncOp.getLoc();
auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 210025e30d7db5..0b469066f290c2 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -220,8 +220,8 @@ func.func @aligned_extsi_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
return %0 : vector<8xi32>
}
-// CHECK-LABEL: func.func @aligned_extsi_2d(
-func.func @aligned_extsi_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-LABEL: func.func @aligned_extsi_i4_2d(
+func.func @aligned_extsi_i4_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
@@ -234,6 +234,72 @@ func.func @aligned_extsi_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
return %0 : vector<8x32xi32>
}
+// CHECK-LABEL: func.func @aligned_extsi_i2_to_i8(
+func.func @aligned_extsi_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi8> {
+// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK: %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK: %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK: %[[RESULT:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+ %0 = arith.extsi %a : vector<8xi2> to vector<8xi8>
+ return %0 : vector<8xi8>
+}
+
+
+// CHECK-LABEL: func.func @aligned_extsi_i2_to_i32(
+func.func @aligned_extsi_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi32> {
+// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK: %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK: %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+// CHECK: %[[RESULT:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
+ %0 = arith.extsi %a : vector<8xi2> to vector<8xi32>
+ return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extsi_i2_2d(
+func.func @aligned_extsi_i2_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<8x8xi8>
+// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<8x8xi8>
+// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<8x8xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi2> to vector<8x8xi8>
+// CHECK: %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
+// CHECK: %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<8x8xi8>
+// CHECK: %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<8x8xi8>
+// CHECK: %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<8x8xi8>
+// CHECK: %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<8x8xi8>
+// CHECK: %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<8x8xi8>
+// CHECK: %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
+// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<8x8xi8>
+// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<8x8xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<8x16xi8>
+// CHECK: %[[RESULT:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+ %0 = arith.extsi %a : vector<8x32xi2> to vector<8x32xi32>
+ return %0 : vector<8x32xi32>
+}
+
// CHECK-LABEL: func.func @aligned_trunci_i8_to_i4(
func.func @aligned_trunci_i8_to_i4(%a: vector<8xi8>) -> vector<8xi4> {
@@ -292,6 +358,13 @@ func.func @aligned_trunci_nd(%a: vector<3x8x32xi32>) -> vector<3x8x32xi4> {
return %0 : vector<3x8x32xi4>
}
+func.func @aligned_trunci_i8_to_i2_no_match(%a: vector<8xi8>) -> vector<8xi2> {
+ // CHECK-NOT: arith.bitcast
+ // CHECK: arith.trunci %[[IN:.*]] : vector<8xi8> to vector<8xi2>
+ %0 = arith.trunci %a : vector<8xi8> to vector<8xi2>
+ return %0 : vector<8xi2>
+}
+
// CHECK-LABEL: func.func @aligned_extui_i4_to_i8(
func.func @aligned_extui_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> {
// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
@@ -319,8 +392,8 @@ func.func @aligned_extui_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
return %0 : vector<8xi32>
}
-// CHECK-LABEL: func.func @aligned_extui_2d(
-func.func @aligned_extui_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-LABEL: func.func @aligned_extui_i4_2d(
+func.func @aligned_extui_i4_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8>
@@ -333,6 +406,74 @@ func.func @aligned_extui_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
return %0 : vector<8x32xi32>
}
+// CHECK-LABEL: func.func @aligned_extui_i2_to_i8(
+func.func @aligned_extui_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi8> {
+// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<2xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK: %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK: %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK: %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK: %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK: %[[SHR_6:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[ELEM3:.*]] = arith.andi %[[SHR_6]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK: %[[RESULT:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+ %0 = arith.extui %a : vector<8xi2> to vector<8xi8>
+ return %0 : vector<8xi8>
+}
+
+// CHECK-LABEL: func.func @aligned_extui_i2_to_i32(
+func.func @aligned_extui_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi32> {
+// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<2xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK: %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK: %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK: %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK: %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK: %[[SHR_6:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[ELEM3:.*]] = arith.andi %[[SHR_6]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+// CHECK: %[[RESULT:.*]] = arith.extui %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
+ %0 = arith.extui %a : vector<8xi2> to vector<8xi32>
+ return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extui_i2_2d(
+func.func @aligned_extui_i2_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<8x8xi8>
+// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<8x8xi8>
+// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<8x8xi8>
+// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<8x8xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi2> to vector<8x8xi8>
+// CHECK: %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// CHECK: %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<8x8xi8>
+// CHECK: %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// CHECK: %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<8x8xi8>
+// CHECK: %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// CHECK: %[[SHR_6:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
+// CHECK: %[[ELEM3:.*]] = arith.andi %[[SHR_6]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<8x8xi8>
+// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<8x8xi...
[truncated]
Thanks for working on this!
> I didn't add the rewrite for i8 to i2 truncation because I currently don't use it, but if this is needed, I can add it as well.
A TODO + negative test would be sufficient for me.
@@ -1489,6 +1610,10 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
        truncOp)))
    return failure();

  // not supported currently.
[nit] I'd move this near `commonConversionPrecondition` - this check is basically "complementing" that hook.
Suggested change:
- // not supported currently.
+ // TODO: Add support for truncating to i2.
Also, I'd add a negative test.
Done. Also, a negative test should exist. I just check that I don't match on an i8 to i2 truncation, is that not sufficient?
Thanks a lot for the review. I believe I have addressed all the comments. I also noticed a superfluous `AndIOp` for the unsigned case, so I removed it and adjusted the tests accordingly.
cc @lialan who has recently been working in this area.
The logic of signed and unsigned conversion is very similar. The only difference is that they use different `extractNBits*` functions. I think we can pass the function to the method, so we no longer need to duplicate the logic. E.g.,
using ExtractNBitsFn = std::function<Value(PatternRewriter&, Location, Value, int, int)>;
static Value rewriteI4ToI8(PatternRewriter &rewriter, Location loc,
Value srcValue, ExtractNBitsFn extFn) {
// ...
Value low = extFn(rewriter, loc, i8Vector, 0, 4);
Value high = extFn(rewriter, loc, i8Vector, 4, 4);
// ...
}
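Applied to the new i2 path, the unified rewrite could then look roughly like the sketch below. This assumes the `ExtractNBitsFn`/`extFn` helper from the suggestion above; it is an illustration, not the code as merged:

```cpp
// Sketch only: i2 -> i8 extension expressed via a shared extraction callback.
static Value rewriteI2ToI8(PatternRewriter &rewriter, Location loc,
                           Value srcValue, ExtractNBitsFn extFn) {
  VectorType srcVecType = cast<VectorType>(srcValue.getType());
  // Bitcast vector<Xxi2> -> vector<X/4xi8>: four i2 elements per byte.
  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
  i8VecShape.back() /= 4;
  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
  // Extract the four 2-bit lanes of each byte (bit indices 0, 2, 4, 6).
  Value elem0 = extFn(rewriter, loc, i8Vector, /*bitIdx=*/0, /*numBits=*/2);
  Value elem1 = extFn(rewriter, loc, i8Vector, /*bitIdx=*/2, /*numBits=*/2);
  Value elem2 = extFn(rewriter, loc, i8Vector, /*bitIdx=*/4, /*numBits=*/2);
  Value elem3 = extFn(rewriter, loc, i8Vector, /*bitIdx=*/6, /*numBits=*/2);
  // Restore the original element order: interleave evens, odds, then both.
  Value i02 = rewriter.create<vector::InterleaveOp>(loc, elem0, elem2);
  Value i13 = rewriter.create<vector::InterleaveOp>(loc, elem1, elem3);
  return rewriter.create<vector::InterleaveOp>(loc, i02, i13);
}
```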
Thanks again for the review. All comments should be addressed.
Thanks for all the updates so far!
/// Extracts a signed N-bit sequence from each element of an 8-bit vector,
/// starting at the specified bit index.
I am a bit confused about the naming and this description.
IIUC, this method will:
- for every byte `b` in `src` (which is a vector of bytes),
- extract `numBits` starting at `bitIdx` (let's call it `inputVal`), and
- return a byte matching the value encoded in `inputVal`.

So this method is more like `extractNBitsAndReturnAsByte`?
> - for every byte `b` in `src` (which is a vector of bytes),
> - extracts `numBits` starting at `bitIdx` (let's call it `inputVal`), and
> - returns a byte matching the value encoded in `inputVal`.

It will extract `numBits` for every byte of `src` at `bitIdx` and will return a vector of bytes; the resultType will always be the same as the srcType.
So, for example, let's say `numBits` is 4: it will treat the `inputVal` as an `i4` and (sign)ext it to an `i8` value.
I'm not sure about the name either, maybe `extractNBitsAnd(Sign)ExtToI8`?
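As a concrete illustration (my own numbers, not taken from the patch): with `numBits = 4` and `bitIdx = 0`, the source byte `0b0110_1011` yields the low nibble `0b1011`; read as `i4` that is `-5`, and sign-extending it produces the byte `0b1111_1011`, i.e. the `i8` value `-5`.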
`extractNbitsPerByteAndExtendToI8`?
Am I correct that this method assumes that the src and dst element type is `i8`?
Yes.
/// shr = src >> 6 = [00010110]
/// result = shr & mask = [00000010]
Wouldn't `arith::shl` -> `arith::shrui` work here just fine?
It would, but you would end up with more instructions in total, since to extract the bits starting at position 0 you would still have to do the shl + shr, whereas with the mask you can just do the mask once.
Example for 4 bit:
Shifts:
src : i8
shl = arith::shl(src)
low = arith::shrui(shl)
high = arith::shrui(src)
Mask:
src : i8
low = mask(src, 15)
high = arith::shrui(src)
In an ideal world, we'd keep `extractNBitsFromVectorSigned` and `extractNBitsFromVectorUnsigned` very consistent, so I'm trying to understand the reason for divergence.

> Example for 4 bit:
> Shifts:
> src : i8
> shl = arith::shl(src)
> low = arith::shrui(shl)
> high = arith::shrui(src)

Let me make sure I got this right. I assume that we are extracting 2 bits at index 0 (indexing from MSB, i.e. left-most bit):
src: [10 10 11 00]
shl: arith.shl(src) --> [10 10 11 00]
shr: arith.shrui(shl) --> [00 00 00 10]
Why do we need another `shr` in your example? (i.e., what's `low` and `high`?)
Edit: I'm sorry, I just looked at the example again and there is a mistake: it should be `src >> 2`, not 6. So that's probably where the confusion comes from.
Okay, so my understanding is that when talking about bits, the convention is to "read right to left", so the LSB is at index 0. So the "low bits" are the right-most bits and the "high bits" are the left-most. I will add this to the documentation.
src:     [ 1 | 0 | 1 | 0 | 1 | 1 | 0 | 0 ]
indices: [ 7 | 6 | 5 | 4 | 3 | 2 | 1 | 0 ]
So I can always extract N bits starting from 0 with a single `arith::AndIOp` + mask.
If I used the shl + shrui pattern, I would need two instructions.
So this is basically a small shortcut I can take with unsigned numbers, but for the signed case I always have to do it with the shrsi for the sign.
> Okay, so my understanding is that when talking about bits the convention is to "read right to left", so the LSB is at index 0. So the "low bits" are the right-most bits and the "high bits" are the left-most.

Yes, that makes more sense 😅 But then, when storing a vector of 8 elements in memory, it's usually like: v[0], v[1], v[2], ... . Not to mention the split into little-endian/big-endian (but we are not storing anything in memory here, so not relevant). My point being - this is worth documenting.

> I will add this to the documentation

That was going to be my next request 😅

> So this is basically a small shortcut I can take with unsigned numbers, but for the signed case I always have to do it with the shrsi for the sign.

TBH, this is where I'd really expect LLVM to do its job (as in, I wouldn't consider optimising this much in MLIR). But your measurements might suggest otherwise?
To me, it would make a lot of sense if you merged `extractNBitsFromVectorSigned` and `extractNBitsFromVectorUnsigned` - the only difference would be the right shift. Consistency is great :)
That said, I don't feel that strongly about it, and I admit that it's an optimisation. If you decide to keep `arith::ShRUIOp + arith::AndIOp`, then could you add a comment in `extractNBitsFromVectorUnsigned` to document the reason for divergence? Something along the lines of:

NOTE: Similarly to `extractNBitsFromVectorSigned`, this could be done with `arith::ShLIOp + arith::ShRSIOp`. However, by using `arith::ShRUIOp + arith::AndIOp`, we are eliminating the shift left when the index is 0.
I know that this feels obvious today, but I always try to optimise for my future self who will forget :)
I have not measured whether it would make any difference; I will look into this. For now, I have just added the note. However, changing this would also mean I have to change all the tests for the i4 to i8 rewrites.
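For reference, here is a minimal sketch of the two extraction helpers as described in this thread. The names and exact signatures are assumptions reconstructed from the discussion, not necessarily the merged spellings:

```cpp
// Unsigned: shift the requested bits down to index 0, then mask them off.
// When bitIdx == 0 the shift is skipped - the shortcut discussed above.
static Value extractNBitsFromVectorUnsigned(PatternRewriter &rewriter,
                                            Location loc, Value src,
                                            int bitIdx, int numBits) {
  auto vecType = cast<VectorType>(src.getType());
  Value bits = src;
  if (bitIdx > 0) {
    auto shrAttr = DenseElementsAttr::get(vecType, static_cast<int8_t>(bitIdx));
    Value shr = rewriter.create<arith::ConstantOp>(loc, shrAttr);
    bits = rewriter.create<arith::ShRUIOp>(loc, src, shr);
  }
  int8_t maskVal = static_cast<int8_t>((1 << numBits) - 1); // e.g. 3 for i2
  auto maskAttr = DenseElementsAttr::get(vecType, maskVal);
  Value mask = rewriter.create<arith::ConstantOp>(loc, maskAttr);
  return rewriter.create<arith::AndIOp>(loc, bits, mask);
}

// Signed: shift left so the lane's sign bit lands in bit 7, then
// arithmetic-shift right to sign-extend it across the whole byte.
static Value extractNBitsFromVectorSigned(PatternRewriter &rewriter,
                                          Location loc, Value src,
                                          int bitIdx, int numBits) {
  auto vecType = cast<VectorType>(src.getType());
  int8_t shlAmt = static_cast<int8_t>(8 - numBits - bitIdx);
  int8_t shrAmt = static_cast<int8_t>(8 - numBits);
  Value bits = src;
  if (shlAmt > 0) {
    auto shlAttr = DenseElementsAttr::get(vecType, shlAmt);
    Value shl = rewriter.create<arith::ConstantOp>(loc, shlAttr);
    bits = rewriter.create<arith::ShLIOp>(loc, src, shl);
  }
  auto shrAttr = DenseElementsAttr::get(vecType, shrAmt);
  Value shr = rewriter.create<arith::ConstantOp>(loc, shrAttr);
  return rewriter.create<arith::ShRSIOp>(loc, bits, shr);
}
```

With these, `bitIdx == 0` on the unsigned path is a single `andi`, while `bitIdx == 6`, `numBits == 2` on the signed path degenerates to a single `shrsi` by 6, matching `elem3` in the patch.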
✅ With the latest revision this PR passed the C/C++ code formatter.
- extracts repeated code into functions
- reorder tests
- improve naming
Force-pushed d543d5a to b975051.
I love the updated comments 🙏🏻
Thanks for all the updates, this is looking really neat 🙏🏻 A few final comments, nothing major.
int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
assert(8 % srcBitwidth == 0 &&
       "Unsupported sub-byte type (not a divisor of i8)");
int64_t bitwidthFactor = 8 / srcBitwidth;
To keep the naming consistent within the file, could this be `numSrcElemsPerByte`?
Similar variable elsewhere:
const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;
Done. There are some other places that use the same name, but I'm not sure if they refer to the same thing.
- const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;
+ const int numSrcElemsPerDestElem = 8 / srcElemBitwidth;
Why hard-code this to 8?
Also, comment for L1105 (sadly GitHub doesn't allow comments for things outside the diff) - please update the error msg.
I changed this check since the dst in this case will always be an i8 vector.
For example:
%0 = arith.extui %a : vector<4xi2> to vector<4xi32>
should be possible, because it's only important that I can pack the 4 i2 values into a 1xi8 vector (otherwise I would have an invalid bitcast).
But I also added a test that checks that something like this fails:
%0 = arith.extui %a : vector<2xi2> to vector<2xi32>
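A minimal sketch of the alignment arithmetic being discussed (the variable names are assumptions, not necessarily the merged code): what matters is that the trailing dimension packs evenly into whole bytes, independent of the destination element bitwidth.

```cpp
// vector<4xi2> passes: 4 % (8 / 2) == 0, so it bitcasts to vector<1xi8>.
// vector<2xi2> fails:  2 % (8 / 2) != 0, so no valid bitcast exists.
int64_t srcElemBitwidth = srcVecType.getElementTypeBitWidth();
int64_t numSrcElemsPerByte = 8 / srcElemBitwidth;
if (srcVecType.getShape().back() % numSrcElemsPerByte != 0)
  return rewriter.notifyMatchFailure(op, "unaligned trailing dimension");
```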
// CHECK-LABEL: func.func @aligned_i4_trailing_dim_not_multiple(
func.func @aligned_i4_trailing_dim_not_multiple(%a: vector<1xi4>) -> vector<1xi8> {
  // CHECK-NOT: arith.bitcast
  // CHECK: arith.extsi %[[IN:.*]] : vector<1xi4> to vector<1xi8>
  %0 = arith.extsi %a : vector<1xi4> to vector<1xi8>
  return %0 : vector<1xi8>
}

// CHECK-LABEL: func.func @aligned_i2_trailing_dim_not_multiple(
func.func @aligned_i2_trailing_dim_not_multiple(%a: vector<2xi2>) -> vector<2xi8> {
  // CHECK-NOT: arith.bitcast
  // CHECK: arith.extsi %[[IN:.*]] : vector<2xi2> to vector<2xi8>
  %0 = arith.extsi %a : vector<2xi2> to vector<2xi8>
  return %0 : vector<2xi8>
}
These are the unaligned cases, right? Please move to https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir.
Ping
This was meant more as a negative test for the alignedConversionPrecondition. Also, I think the file you linked refers to `vector::populateVectorNarrowTypeEmulationPatterns`, and I changed `vector::populateVectorNarrowTypeRewritePatterns` (in the same file), if I understood correctly?
Right, now I see. Then, still, `aligned` --> `unaligned` ;-) Also, to match the naming convention that follows this:
@aligned_i4_trailing_dim_not_multiple --> @unaligned_extsi_i4_to_i8_trailing_dim_not_multiple
In fact, I'd do this:
// Negative test - the trailing dim 2 is not a multiple of 4 (i.e. 8 / 2).
func.func @unaligned_extsi_i4_to_i8
Similar comment for the other negative test.
Thanks again for the comments.
I like the updated documentation, thanks a lot! I just have a final optional nit.
@banach-space Sorry, I didn't mean to ignore the comments. I somehow started my own review on the PR or something, so all the comments were only "pending." I didn't know what this meant, but it appears they were only visible to me. 😅 I hope they're visible now!
No worries, GitHub is pretty bad for tracking threads of conversation. Hence my "ping" - I suspected that some things were displaying for only one of us ;-)
LGTM % some minor requests.
This is quite intricate, and I really appreciate the effort you've put into implementing this and so patiently addressing all my comments.
The comments should be resolved now. Thanks as well for all the reviews and being patient with me :)
Works both ways :) Do you have commit access?
No, I would be happy if you could merge it.
Landed 🥳 Thanks for working on this 🙏🏻 Note, GitHub added me as a co-author. AFAIK, that's the default behaviour when squashing commits like this (IIUC, you've incorporated my suggestions through GitHub's UI). I don't really mind, but obviously you've done 99% of the work here 😅
Referenced in a follow-up patch series:

This is PR 4 in a series of N patches aimed at improving "VectorEmulateNarrowType.cpp". This is mainly minor refactoring; no major functional changes are made/added.

1. Update `alignedConversionPrecondition` (1): This method didn't require the vector type for the "destination" argument. The underlying element type is sufficient. The corresponding argument has been renamed as `multiByteScalarTy` - this is meant as the multi-byte emulated type (`i8`, `i16`, `i32`, etc).

2. Update `alignedConversionPrecondition` (2): In llvm#121298, we replaced `dstElemBitwidth` in this calculation:

```cpp
const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;
```

with the hard-coded value of 8:

```cpp
const int numSrcElemsPerDestElem = 8 / srcElemBitwidth;
```

That was correct for the patterns for which this hook was/is used:
* `RewriteAlignedSubByteIntExt`,
* `RewriteAlignedSubByteIntTrunc`.

The destination type (or, more precisely, the emulated type) was always `i8`. In this PR, I am switching back to a more generic approach - the calculation should take into account the bit-width of the emulated type. Note that at the call sites I am passing `i8` as the emulated type, so the end result is effectively identical. However, the intent is clearer, i.e., the underlying value is 8 because the emulated type happens to be `i8` (as opposed to using a magic number).

3. Update `alignedConversionPrecondition` (3): The final check has been replaced with a new helper method, `isSubByteVecFittable`. This new method is also re-used within the code and hopefully will allow us more code re-use moving forward (to avoid re-implementing the same condition).

4. Update `alignedConversionPrecondition` (4): NEXT STEPS: We need to clarify the meaning of "source" and "destination" types. Currently the usage is ambiguous. For example, for this `arith.extsi` op, `vector<8xi2>` and `vector<8xi32>` are the "source" and "destination" types, respectively:

```mlir
%0 = arith.extsi %arg0 : vector<8xi2> to vector<8xi32>
```

However, patterns like `RewriteAlignedSubByteIntExt` introduce `vector.bitcast` ops like this:

```mlir
%bitcast = vector.bitcast %arg0 : vector<8xi2> to vector<2xi8>
```

I've noticed that we tend to mix `vector<2xi8>` and `vector<8xi32>` as the destination types and that should be clarified.
Adds rewrites for i2 to i8 signed and unsigned extension, similar to the ones that already exist for i4 to i8 conversion.
I use this for i6 quantized models, and this gives me roughly a 2x speedup for an i6 4096x4096 dequantization-matmul on an AMD 5950x.
I didn't add the rewrite for i8 to i2 truncation because I currently don't use it, but if this is needed, I can add it as well.