[MLIR] support dynamic indexing of vector.maskedload in VectorEmulateNarrowTypes #115070
Conversation
@llvm/pr-subscribers-mlir
Author: lialan (lialan)
Changes
Based on the existing emulation scheme, this patch expands support to dynamic indexing by dynamically creating the intermediate new mask and new passthru vector, and dynamically inserting the result into the destination vector. The dynamic parts are constructed by multiple `vector.extract` and `vector.insert` ops, as `vector.insert_strided_slice` and `vector.extract_strided_slice` only take static offsets and indices.
Full diff: https://github.com/llvm/llvm-project/pull/115070.diff
2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index f169dab3bdd9af..56273ac2899d7e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -53,6 +53,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
Location loc, Value mask,
int origElements, int scale,
int intraDataOffset = 0) {
+ assert(intraDataOffset < scale && "intraDataOffset must be less than scale");
auto numElements = (intraDataOffset + origElements + scale - 1) / scale;
Operation *maskOp = mask.getDefiningOp();
@@ -182,6 +183,25 @@ static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
return dest;
}
+/// Inserts a 1-D subvector into a 1-D `dest` vector at the dynamic offset `destOffsetVar`.
+static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
+ TypedValue<VectorType> source,
+ Value dest, OpFoldResult destOffsetVar,
+ int64_t length) {
+ assert(length > 0 && "length must be greater than 0");
+ for (int i = 0; i < length; ++i) {
+ Value insertLoc =
+ i == 0
+ ? destOffsetVar.dyn_cast<Value>()
+ : rewriter.create<arith::AddIOp>(
+ loc, rewriter.getIndexType(), destOffsetVar.dyn_cast<Value>(),
+ rewriter.create<arith::ConstantIndexOp>(loc, i));
+ auto extractOp = rewriter.create<vector::ExtractOp>(loc, source, i);
+ dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, insertLoc);
+ }
+ return dest;
+}
+
/// Returns the op sequence for an emulated sub-byte data type vector load.
/// specifically, use `emulatedElemType` for loading a vector of `origElemType`.
/// The load location is given by `base` and `linearizedIndices`, and the
@@ -199,7 +219,7 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
return rewriter.create<vector::BitCastOp>(
loc, VectorType::get(numEmultedElementsToLoad * scale, origElemType),
newLoad);
-};
+}
namespace {
@@ -546,29 +566,30 @@ struct ConvertVectorMaskedLoad final
? getConstantIntValue(linearizedInfo.intraDataOffset)
: 0;
- if (!foldedIntraVectorOffset) {
- // unimplemented case for dynamic intra vector offset
- return failure();
- }
-
- FailureOr<Operation *> newMask =
- getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale,
- *foldedIntraVectorOffset);
+ auto maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
+ FailureOr<Operation *> newMask = getCompressedMaskOp(
+ rewriter, loc, op.getMask(), origElements, scale, maxIntraDataOffset);
if (failed(newMask))
return failure();
+ Value passthru = op.getPassThru();
+
auto numElements =
- llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
+ llvm::divideCeil(maxIntraDataOffset + origElements, scale);
auto loadType = VectorType::get(numElements, newElementType);
auto newBitcastType = VectorType::get(numElements * scale, oldElementType);
- Value passthru = op.getPassThru();
- if (isUnalignedEmulation) {
- // create an empty vector of the new type
- auto emptyVector = rewriter.create<arith::ConstantOp>(
- loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
- passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
- *foldedIntraVectorOffset);
+ auto emptyVector = rewriter.create<arith::ConstantOp>(
+ loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
+ if (foldedIntraVectorOffset) {
+ if (isUnalignedEmulation) {
+ passthru = staticallyInsertSubvector(
+ rewriter, loc, passthru, emptyVector, *foldedIntraVectorOffset);
+ }
+ } else {
+ passthru = dynamicallyInsertSubVector(
+ rewriter, loc, dyn_cast<TypedValue<VectorType>>(passthru),
+ emptyVector, linearizedInfo.intraDataOffset, origElements);
}
auto newPassThru =
rewriter.create<vector::BitCastOp>(loc, loadType, passthru);
@@ -585,23 +606,36 @@ struct ConvertVectorMaskedLoad final
rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);
Value mask = op.getMask();
- if (isUnalignedEmulation) {
- auto newSelectMaskType =
- VectorType::get(numElements * scale, rewriter.getI1Type());
- // TODO: can fold if op's mask is constant
- auto emptyVector = rewriter.create<arith::ConstantOp>(
- loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
- mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyVector,
- *foldedIntraVectorOffset);
+ auto newSelectMaskType =
+ VectorType::get(numElements * scale, rewriter.getI1Type());
+ // TODO: try to fold if op's mask is constant
+ auto emptyMask = rewriter.create<arith::ConstantOp>(
+ loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
+ if (foldedIntraVectorOffset) {
+ if (isUnalignedEmulation) {
+ mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
+ *foldedIntraVectorOffset);
+ }
+ } else {
+ mask = dynamicallyInsertSubVector(
+ rewriter, loc, dyn_cast<TypedValue<VectorType>>(mask), emptyMask,
+ linearizedInfo.intraDataOffset, origElements);
}
Value result =
rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
-
- if (isUnalignedEmulation) {
- result =
- staticallyExtractSubvector(rewriter, loc, op.getType(), result,
- *foldedIntraVectorOffset, origElements);
+ if (foldedIntraVectorOffset) {
+ if (isUnalignedEmulation) {
+ result =
+ staticallyExtractSubvector(rewriter, loc, op.getType(), result,
+ *foldedIntraVectorOffset, origElements);
+ }
+ } else {
+ auto resultVector = rewriter.create<arith::ConstantOp>(
+ loc, op.getType(), rewriter.getZeroAttr(op.getType()));
+ result = dynamicallyExtractSubVector(
+ rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
+ linearizedInfo.intraDataOffset, origElements);
}
rewriter.replaceOp(op, result);
@@ -659,10 +693,9 @@ struct ConvertVectorTransferRead final
? getConstantIntValue(linearizedInfo.intraDataOffset)
: 0;
- auto maxIntraVectorOffset =
- foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1;
+ auto maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
auto numElements =
- llvm::divideCeil(maxIntraVectorOffset + origElements, scale);
+ llvm::divideCeil(maxIntraDataOffset + origElements, scale);
auto newRead = rewriter.create<vector::TransferReadOp>(
loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 0cecaddc5733e2..6a10a2f9ed32fe 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -183,3 +183,54 @@ func.func @vector_transfer_read_i2_dynamic_indexing_mixed(%idx1: index) -> vecto
// CHECK: %[[C2:.+]] = arith.constant 2 : index
// CHECK: %[[ADDI2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index
// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]]] : i2 from vector<8xi2>
+// -----
+
+func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>, %idx: index) -> vector<3xi2> {
+ %0 = memref.alloc() : memref<3x3xi2>
+ %cst = arith.constant dense<0> : vector<3x3xi2>
+ %c2 = arith.constant 2 : index
+ %mask = vector.constant_mask [3] : vector<3xi1>
+ %1 = vector.maskedload %0[%idx, %c2], %mask, %passthru :
+ memref<3x3xi2>, vector<3xi1>, vector<3xi2> into vector<3xi2>
+ return %1 : vector<3xi2>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> ((s0 * 3 + 2) floordiv 4)>
+// CHECK: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 3 - ((s0 * 3 + 2) floordiv 4) * 4 + 2)>
+// CHECK: func @vector_maskedload_i2_dynamic_indexing_mixed(
+// CHECK-SAME: %[[PTH:.+]]: vector<3xi2>, %[[IDX:.+]]: index) -> vector<3xi2>
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[MASK:.+]] = vector.constant_mask [3] : vector<3xi1>
+// CHECK: %[[LINEAR1:.+]] = affine.apply #map()[%[[IDX]]]
+// CHECK: %[[LINEAR2:.+]] = affine.apply #map1()[%[[IDX]]]
+// CHECK: %[[ONE:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<8xi2>
+// CHECK: %[[EX1:.+]] = vector.extract %[[PTH]][0] : i2 from vector<3xi2>
+// CHECK: %[[IN1:.+]] = vector.insert %[[EX1]], %[[ZERO]] [%[[LINEAR2]]] : i2 into vector<8xi2>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[INCIDX:.+]] = arith.addi %[[LINEAR2]], %[[C1]] : index
+// CHECK: %[[EX2:.+]] = vector.extract %[[PTH]][1] : i2 from vector<3xi2>
+// CHECK: %[[IN2:.+]] = vector.insert %[[EX2]], %[[IN1]] [%[[INCIDX]]] : i2 into vector<8xi2>
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[INCIDX2:.+]] = arith.addi %[[LINEAR2]], %[[C2]] : index
+// CHECK: %[[EX3:.+]] = vector.extract %[[PTH]][2] : i2 from vector<3xi2>
+// CHECK: %[[IN3:.+]] = vector.insert %[[EX3]], %[[IN2]] [%[[INCIDX2]]] : i2 into vector<8xi2>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[IN3]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[MASKEDLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LINEAR1]]], %[[ONE]], %[[BITCAST]]
+// CHECK-SAME: memref<3xi8>, vector<2xi1>, vector<2xi8> into vector<2xi8>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[MASKEDLOAD]] : vector<2xi8> to vector<8xi2>
+// CHECK: %[[CST1:.+]] = arith.constant dense<false> : vector<8xi1>
+// CHECK: %[[EX4:.+]] = vector.extract %[[MASK]][0] : i1 from vector<3xi1>
+// CHECK: %[[IN4:.+]] = vector.insert %[[EX4]], %[[CST1]] [%[[LINEAR2]]] : i1 into vector<8xi1>
+// CHECK: %[[EX5:.+]] = vector.extract %[[MASK]][1] : i1 from vector<3xi1>
+// CHECK: %[[IN5:.+]] = vector.insert %[[EX5]], %[[IN4]] [%[[INCIDX]]] : i1 into vector<8xi1>
+// CHECK: %[[EX6:.+]] = vector.extract %[[MASK]][2] : i1 from vector<3xi1>
+// CHECK: %[[IN6:.+]] = vector.insert %[[EX6]], %[[IN5]] [%[[INCIDX2]]] : i1 into vector<8xi1>
+// CHECK: %[[SELECT:.+]] = arith.select %[[IN6]], %[[BITCAST2]], %[[IN3]] : vector<8xi1>, vector<8xi2>
+// CHECK: %[[CST2:.+]] = arith.constant dense<0> : vector<3xi2>
+// CHECK: %[[EX7:.+]] = vector.extract %[[SELECT]][%[[LINEAR2]]] : i2 from vector<8xi2>
+// CHECK: %[[IN7:.+]] = vector.insert %[[EX7]], %[[CST2]] [0] : i2 into vector<3xi2>
+// CHECK: %[[EX8:.+]] = vector.extract %[[SELECT]][%[[INCIDX]]] : i2 from vector<8xi2>
+// CHECK: %[[IN8:.+]] = vector.insert %[[EX8]], %[[IN7]] [1] : i2 into vector<3xi2>
+// CHECK: %[[EX9:.+]] = vector.extract %[[SELECT]][%[[INCIDX2]]] : i2 from vector<8xi2>
+// CHECK: %[[IN9:.+]] = vector.insert %[[EX9]], %[[IN8]] [2] : i2 into vector<3xi2>
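The two affine maps in the CHECK lines above linearize the %0[%idx, %c2] access into a byte index and an intra-byte offset. A minimal plain-C++ sketch of the same arithmetic (the helper name is hypothetical, not part of the patch):

// Four i2 elements are packed per i8 byte, so scale == 4.
// Returns the byte index (#MAP); writes the offset within that byte (#MAP1).
int linearizeI2Access(int idx, int &intraOffset) {
  int linearElement = idx * 3 + 2;   // row-major element index into memref<3x3xi2>
  intraOffset = linearElement % 4;   // #MAP1: (s0 * 3 + 2) mod 4
  return linearElement / 4;          // #MAP:  (s0 * 3 + 2) floordiv 4
}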
Just a couple of drive-by comments - addressing those will help reviewing. Will take a proper look later, thanks!
@@ -53,6 +53,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
Location loc, Value mask,
int origElements, int scale,
int intraDataOffset = 0) {
assert(intraDataOffset < scale && "intraDataOffset must be less than scale");
This is not clear from the method name ... What are `origElements`, `scale` and `intraDataOffset`?
`origElements` is the number of elements of the sub-byte vector. `scale` is byte-emulated element type size / original element type size; for example, if the original element type is `i2`, then the scale is `sizeof(i8)/sizeof(i2) = 4`. `intraDataOffset` is the element offset into the emulated byte. For example, to extract the second slice of `vector<3xi2>` out of a `vector<3x3xi2>` (here we assume the sub-byte elements are stored packed in memory), we would need to load 2 bytes (the first and second byte) and extract bit [7, 14) from them; the first 3 elements are irrelevant in this case, hence `intraDataOffset == 3`.
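A small plain-C++ sketch of how these quantities relate for that i2 example (the variable names are only illustrative; the last formula mirrors the numElements computation in getCompressedMaskOp):

// i2 elements emulated in i8 storage.
int origElements = 3;               // elements in the sub-byte vector being loaded
int scale = 8 / 2;                  // sizeof(i8) / sizeof(i2) == 4
int intraDataOffset = 3;            // element offset of the slice within its byte
int numBytesToLoad =                // emulated (byte) elements that must be loaded
    (intraDataOffset + origElements + scale - 1) / scale;  // == 2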
OK, this was not clear to me at all :)
I was trying to understand all of this a bit better and am just thinking that this logic needs TLC. The comment for this method needs updating to capture the info that you shared above. I think that it would also be good to provide more descriptive argument names.
Now, I appreciate that it wasn't you who wrote this to begin with and updating this shouldn't be a blocker for this PR. Some help would be appreciated. Also, I want to help:
Here's my attempt to improve the comments and input variable names:
Please let me know whether that makes sense to you, and any feedback is welcome.
Note, I've also created these two (again, your feedback would be appreciated):
- [mlir][vector] Restrict narrow-type-emulation patterns #115612
- Bugs in patterns under `populateVectorNarrowTypeEmulationPatterns` (1D vs 2D) #115653
Last, but not least, this example seems off. In particular:
/// [Comment from Andrzej] 6 elements
/// %mask = [1, 1, 0, 0, 0, 0]
///
/// will first be padded with number of `intraDataOffset` zeros:
/// [Comment from Andrzej] 8 elements != 6 + 1
/// %mask = [0, 1, 1, 0, 0, 0, 0, 0]
Shouldn't the padded mask be %mask = [0, 1, 1, 0, 0, 0, 0] (7 elements)?
Btw, thanks so much for working on this - your efforts are truly appreciated! Please don’t let my comments (and appetite to improve things overall) give you any other impression 😅.
You are right! Here I just exposed some intermediate calculation details in the comment; as in this case `scale == 2`, making the padded mask a multiple of `scale` in the intermediate result is easier.
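To make the padding and compression step concrete, here is a small scalar model in C++; it is not the builder code in getCompressedMaskOp, just the element-level arithmetic being described:

#include <vector>

// Pad an element-level mask with `intraDataOffset` leading zeros, then collapse
// every `scale` element slots into one byte-level mask bit, set if any slot is set.
std::vector<bool> compressMask(const std::vector<bool> &mask, int scale,
                               int intraDataOffset) {
  std::vector<bool> padded(intraDataOffset, false);
  padded.insert(padded.end(), mask.begin(), mask.end());
  std::vector<bool> compressed((padded.size() + scale - 1) / scale, false);
  for (size_t i = 0; i < padded.size(); ++i)
    if (padded[i])
      compressed[i / scale] = true;
  return compressed;
}

// Example from the discussion: mask = [1,1,0,0,0,0], scale = 2, offset = 1
// -> padded [0,1,1,0,0,0,0] -> compressed [1,1,0,0].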
Slightly updated the comment part. Can you take a look at it again?
Sorry for the delay, I'm just a bit confused by the current state of affairs in this area 😅
I really appreciate you working on this and don't want to block it, but we should also think about improving the maintainability of this code. Just to be on the constructive side of things (please review):
I will take another look at this later.
Mostly nits. I'm a bit distracted with the other PRs :(
Force-pushed from be34311 to d6437e9.
Thanks a lot to @banach-space who provides so many useful review comments and keeps the code/comments healthy! Also thanks to @lialan who helps the effort!
The PR itself looks good to me, just a few comments.
IIRC, it adds the support when the source mask is from `vector.constant_mask`; the case where the source mask is from `vector.create_mask` is not supported yet. Am I correct? If so, could you add such information to the PR description?
LGTM. Please wait for @banach-space before landing the PR.
LGTM, thanks for bearing with me!
In `staticallyExtractSubvector`, when the extracted slice is the same as the source vector, there is no need to emit `vector.extract_strided_slice`. This fixes the lit test case `@vector_store_i4` in `mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir`, where converting from `vector<8xi4>` to `vector<4xi8>` does not need slice extraction. The issue was introduced in #113411 and #115070; CI failure link: https://buildkite.com/llvm-project/github-pull-requests/builds/118845. This PR does not include a lit test case because it is a fix and the above-mentioned `@vector_store_i4` test already exercises the mechanism. Signed-off-by: Alan Li <[email protected]>
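A hedged sketch of the guard this fix describes (the signature is abbreviated from the file above and the exact builder call is an assumption; the early return is the point):

// If the requested slice covers the whole source vector, return the source
// directly instead of emitting vector.extract_strided_slice.
static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
                                        Value source, int64_t frontOffset,
                                        int64_t subvecSize) {
  auto vectorType = cast<VectorType>(source.getType());
  if (frontOffset == 0 && vectorType.getNumElements() == subvecSize)
    return source;
  SmallVector<int64_t> offsets{frontOffset}, sizes{subvecSize}, strides{1};
  return rewriter.create<vector::ExtractStridedSliceOp>(loc, source, offsets,
                                                        sizes, strides);
}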
Based on the existing emulation scheme, this patch expands support to dynamic indexing by dynamically creating the intermediate new mask and new passthru vector, and dynamically inserting the result into the destination vector.
The dynamic parts are constructed by multiple `vector.extract` and `vector.insert` ops to rearrange the original mask/passthru vector, as `vector.insert_strided_slice` and `vector.extract_strided_slice` only take static offsets and indices.
Note: currently only `vector.maskedload` with masks created by `vector.constant_mask` is supported; `vector.create_mask` is currently not working.
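Finally, a rough scalar model in plain C++ of what the emulated dynamic-index masked load computes for the test case above. It assumes little-endian packing of i2 elements within a byte and always reads both bytes (the real lowering uses the compressed byte mask for the load), so it is only a sketch of the semantics:

#include <array>
#include <cstdint>

// Emulated maskedload of vector<3xi2> starting at dynamic element offset
// `linear` from packed i2 storage: pad passthru/mask, load + bitcast, select,
// then extract the 3-element result slice.
std::array<uint8_t, 3> emulatedMaskedLoadI2(const uint8_t *bytes, int linear,
                                            const std::array<bool, 3> &mask,
                                            const std::array<uint8_t, 3> &passthru) {
  constexpr int scale = 4;                 // i2 elements per i8 byte
  int byteIndex = linear / scale;          // cf. #MAP
  int intraOffset = linear % scale;        // cf. #MAP1
  std::array<uint8_t, 8> padded{};         // padded passthru: 8 i2 slots = 2 bytes
  std::array<bool, 8> paddedMask{};        // padded element-level mask
  for (int i = 0; i < 3; ++i) {            // dynamicallyInsertSubVector
    padded[intraOffset + i] = passthru[i];
    paddedMask[intraOffset + i] = mask[i];
  }
  for (int slot = 0; slot < 8; ++slot) {   // maskedload + bitcast + arith.select
    uint8_t byte = bytes[byteIndex + slot / scale];
    uint8_t loaded = (byte >> (2 * (slot % scale))) & 0x3;
    if (paddedMask[slot])
      padded[slot] = loaded;
  }
  std::array<uint8_t, 3> result{};
  for (int i = 0; i < 3; ++i)              // dynamicallyExtractSubVector
    result[i] = padded[intraOffset + i];
  return result;
}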