[MLIR] support dynamic indexing of vector.maskedload in VectorEmulateNarrowTypes #115070

Merged

8 commits merged into llvm:main from lialan/dynamic_masked_load on Nov 12, 2024

Conversation

@lialan (Member) commented Nov 5, 2024

Based on the existing emulation scheme, this patch adds support for dynamic indexing by dynamically creating an intermediate new mask and a new pass-through vector, and by dynamically inserting the result into the destination vector.

The dynamic parts are constructed with multiple vector.extract and vector.insert ops that rearrange the original mask/passthru vectors, since vector.insert_strided_slice and vector.extract_strided_slice only accept static offsets and indices.

Note: currently only vector.maskedload ops whose masks come from vector.constant_mask are supported; masks from vector.create_mask do not work yet.
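For illustration, a minimal example of the kind of op this pattern now handles (adapted from the lit test added in this PR, with explanatory comments):

func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>, %idx: index) -> vector<3xi2> {
  %0 = memref.alloc() : memref<3x3xi2>
  %c2 = arith.constant 2 : index
  %mask = vector.constant_mask [3] : vector<3xi1>
  // %idx is only known at runtime, so the intra-byte offset of the i2
  // elements cannot be folded to a constant; the mask and passthru are
  // rearranged with dynamic vector.extract/vector.insert ops instead.
  %1 = vector.maskedload %0[%idx, %c2], %mask, %passthru :
    memref<3x3xi2>, vector<3xi1>, vector<3xi2> into vector<3xi2>
  return %1 : vector<3xi2>
}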

@llvmbot (Member) commented Nov 5, 2024

@llvm/pr-subscribers-mlir

Author: lialan (lialan)

Changes

Based on the existing emulation scheme, this patch adds support for dynamic indexing by dynamically creating an intermediate new mask and a new pass-through vector, and by dynamically inserting the result into the destination vector.

The dynamic parts are constructed with multiple vector.extract and vector.insert ops that rearrange the original mask/passthru vectors, since vector.insert_strided_slice and vector.extract_strided_slice only accept static offsets and indices.


Full diff: https://github.com/llvm/llvm-project/pull/115070.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+66-33)
  • (modified) mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir (+51)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index f169dab3bdd9af..56273ac2899d7e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -53,6 +53,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                   Location loc, Value mask,
                                                   int origElements, int scale,
                                                   int intraDataOffset = 0) {
+  assert(intraDataOffset < scale && "intraDataOffset must be less than scale");
   auto numElements = (intraDataOffset + origElements + scale - 1) / scale;
 
   Operation *maskOp = mask.getDefiningOp();
@@ -182,6 +183,25 @@ static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
   return dest;
 }
 
+/// Inserts a 1-D subvector into a 1-D `dest` vector at index `destOffsetVar`.
+static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
+                                        TypedValue<VectorType> source,
+                                        Value dest, OpFoldResult destOffsetVar,
+                                        int64_t length) {
+  assert(length > 0 && "length must be greater than 0");
+  for (int i = 0; i < length; ++i) {
+    Value insertLoc =
+        i == 0
+            ? destOffsetVar.dyn_cast<Value>()
+            : rewriter.create<arith::AddIOp>(
+                  loc, rewriter.getIndexType(), destOffsetVar.dyn_cast<Value>(),
+                  rewriter.create<arith::ConstantIndexOp>(loc, i));
+    auto extractOp = rewriter.create<vector::ExtractOp>(loc, source, i);
+    dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, insertLoc);
+  }
+  return dest;
+}
+
 /// Returns the op sequence for an emulated sub-byte data type vector load.
 /// specifically, use `emulatedElemType` for loading a vector of `origElemType`.
 /// The load location is given by `base` and `linearizedIndices`, and the
@@ -199,7 +219,7 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
   return rewriter.create<vector::BitCastOp>(
       loc, VectorType::get(numEmultedElementsToLoad * scale, origElemType),
       newLoad);
-};
+}
 
 namespace {
 
@@ -546,29 +566,30 @@ struct ConvertVectorMaskedLoad final
             ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
-    if (!foldedIntraVectorOffset) {
-      // unimplemented case for dynamic intra vector offset
-      return failure();
-    }
-
-    FailureOr<Operation *> newMask =
-        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale,
-                            *foldedIntraVectorOffset);
+    auto maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
+    FailureOr<Operation *> newMask = getCompressedMaskOp(
+        rewriter, loc, op.getMask(), origElements, scale, maxIntraDataOffset);
     if (failed(newMask))
       return failure();
 
+    Value passthru = op.getPassThru();
+
     auto numElements =
-        llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
+        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
     auto loadType = VectorType::get(numElements, newElementType);
     auto newBitcastType = VectorType::get(numElements * scale, oldElementType);
 
-    Value passthru = op.getPassThru();
-    if (isUnalignedEmulation) {
-      // create an empty vector of the new type
-      auto emptyVector = rewriter.create<arith::ConstantOp>(
-          loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
-      passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
-                                           *foldedIntraVectorOffset);
+    auto emptyVector = rewriter.create<arith::ConstantOp>(
+        loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
+    if (foldedIntraVectorOffset) {
+      if (isUnalignedEmulation) {
+        passthru = staticallyInsertSubvector(
+            rewriter, loc, passthru, emptyVector, *foldedIntraVectorOffset);
+      }
+    } else {
+      passthru = dynamicallyInsertSubVector(
+          rewriter, loc, dyn_cast<TypedValue<VectorType>>(passthru),
+          emptyVector, linearizedInfo.intraDataOffset, origElements);
     }
     auto newPassThru =
         rewriter.create<vector::BitCastOp>(loc, loadType, passthru);
@@ -585,23 +606,36 @@ struct ConvertVectorMaskedLoad final
         rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);
 
     Value mask = op.getMask();
-    if (isUnalignedEmulation) {
-      auto newSelectMaskType =
-          VectorType::get(numElements * scale, rewriter.getI1Type());
-      // TODO: can fold if op's mask is constant
-      auto emptyVector = rewriter.create<arith::ConstantOp>(
-          loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
-      mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyVector,
-                                       *foldedIntraVectorOffset);
+    auto newSelectMaskType =
+        VectorType::get(numElements * scale, rewriter.getI1Type());
+    // TODO: try to fold if op's mask is constant
+    auto emptyMask = rewriter.create<arith::ConstantOp>(
+        loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
+    if (foldedIntraVectorOffset) {
+      if (isUnalignedEmulation) {
+        mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
+                                         *foldedIntraVectorOffset);
+      }
+    } else {
+      mask = dynamicallyInsertSubVector(
+          rewriter, loc, dyn_cast<TypedValue<VectorType>>(mask), emptyMask,
+          linearizedInfo.intraDataOffset, origElements);
     }
 
     Value result =
         rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
-
-    if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+    if (foldedIntraVectorOffset) {
+      if (isUnalignedEmulation) {
+        result =
+            staticallyExtractSubvector(rewriter, loc, op.getType(), result,
+                                       *foldedIntraVectorOffset, origElements);
+      }
+    } else {
+      auto resultVector = rewriter.create<arith::ConstantOp>(
+          loc, op.getType(), rewriter.getZeroAttr(op.getType()));
+      result = dynamicallyExtractSubVector(
+          rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
+          linearizedInfo.intraDataOffset, origElements);
     }
     rewriter.replaceOp(op, result);
 
@@ -659,10 +693,9 @@ struct ConvertVectorTransferRead final
             ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
-    auto maxIntraVectorOffset =
-        foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1;
+    auto maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
     auto numElements =
-        llvm::divideCeil(maxIntraVectorOffset + origElements, scale);
+        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
         loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 0cecaddc5733e2..6a10a2f9ed32fe 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -183,3 +183,54 @@ func.func @vector_transfer_read_i2_dynamic_indexing_mixed(%idx1: index) -> vecto
 // CHECK: %[[C2:.+]] = arith.constant 2 : index
 // CHECK: %[[ADDI2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index
 // CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]]] : i2 from vector<8xi2>
+// -----
+
+func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>, %idx: index) -> vector<3xi2> {
+  %0 = memref.alloc() : memref<3x3xi2>
+  %cst = arith.constant dense<0> : vector<3x3xi2>
+  %c2 = arith.constant 2 : index
+  %mask = vector.constant_mask [3] : vector<3xi1>
+  %1 = vector.maskedload %0[%idx, %c2], %mask, %passthru :
+    memref<3x3xi2>, vector<3xi1>, vector<3xi2> into vector<3xi2>
+  return %1 : vector<3xi2>
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> ((s0 * 3 + 2) floordiv 4)>
+// CHECK: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 3 - ((s0 * 3 + 2) floordiv 4) * 4 + 2)>
+// CHECK: func @vector_maskedload_i2_dynamic_indexing_mixed(
+// CHECK-SAME: %[[PTH:.+]]: vector<3xi2>, %[[IDX:.+]]: index) -> vector<3xi2>
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[MASK:.+]] = vector.constant_mask [3] : vector<3xi1>
+// CHECK: %[[LINEAR1:.+]] = affine.apply #map()[%[[IDX]]]
+// CHECK: %[[LINEAR2:.+]] = affine.apply #map1()[%[[IDX]]]
+// CHECK: %[[ONE:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<8xi2>
+// CHECK: %[[EX1:.+]] = vector.extract %[[PTH]][0] : i2 from vector<3xi2>
+// CHECK: %[[IN1:.+]] = vector.insert %[[EX1]], %[[ZERO]] [%[[LINEAR2]]] : i2 into vector<8xi2>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[INCIDX:.+]] = arith.addi %[[LINEAR2]], %[[C1]] : index
+// CHECK: %[[EX2:.+]] = vector.extract %[[PTH]][1] : i2 from vector<3xi2>
+// CHECK: %[[IN2:.+]] = vector.insert %[[EX2]], %[[IN1]] [%[[INCIDX]]] : i2 into vector<8xi2>
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[INCIDX2:.+]] = arith.addi %[[LINEAR2]], %[[C2]] : index
+// CHECK: %[[EX3:.+]] = vector.extract %[[PTH]][2] : i2 from vector<3xi2>
+// CHECK: %[[IN3:.+]] = vector.insert %[[EX3]], %[[IN2]] [%[[INCIDX2]]] : i2 into vector<8xi2>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[IN3]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[MASKEDLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LINEAR1]]], %[[ONE]], %[[BITCAST]]
+// CHECK-SAME: memref<3xi8>, vector<2xi1>, vector<2xi8> into vector<2xi8>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[MASKEDLOAD]] : vector<2xi8> to vector<8xi2>
+// CHECK: %[[CST1:.+]] = arith.constant dense<false> : vector<8xi1>
+// CHECK: %[[EX4:.+]] = vector.extract %[[MASK]][0] : i1 from vector<3xi1>
+// CHECK: %[[IN4:.+]] = vector.insert %[[EX4]], %[[CST1]] [%[[LINEAR2]]] : i1 into vector<8xi1>
+// CHECK: %[[EX5:.+]] = vector.extract %[[MASK]][1] : i1 from vector<3xi1>
+// CHECK: %[[IN5:.+]] = vector.insert %[[EX5]], %[[IN4]] [%[[INCIDX]]] : i1 into vector<8xi1>
+// CHECK: %[[EX6:.+]] = vector.extract %[[MASK]][2] : i1 from vector<3xi1>
+// CHECK: %[[IN6:.+]] = vector.insert %[[EX6]], %[[IN5]] [%[[INCIDX2]]] : i1 into vector<8xi1>
+// CHECK: %[[SELECT:.+]] = arith.select %[[IN6]], %[[BITCAST2]], %[[IN3]] : vector<8xi1>, vector<8xi2>
+// CHECK: %[[CST2:.+]] = arith.constant dense<0> : vector<3xi2>
+// CHECK: %[[EX7:.+]] = vector.extract %[[SELECT]][%[[LINEAR2]]] : i2 from vector<8xi2>
+// CHECK: %[[IN7:.+]] = vector.insert %[[EX7]], %[[CST2]] [0] : i2 into vector<3xi2>
+// CHECK: %[[EX8:.+]] = vector.extract %[[SELECT]][%[[INCIDX]]] : i2 from vector<8xi2>
+// CHECK: %[[IN8:.+]] = vector.insert %[[EX8]], %[[IN7]] [1] : i2 into vector<3xi2>
+// CHECK: %[[EX9:.+]] = vector.extract %[[SELECT]][%[[INCIDX2]]] : i2 from vector<8xi2>
+// CHECK: %[[IN9:.+]] = vector.insert %[[EX9]], %[[IN8]] [2] : i2 into vector<3xi2>

@llvmbot (Member) commented Nov 5, 2024

@llvm/pr-subscribers-mlir-vector

github-actions bot commented Nov 5, 2024

⚠️ We detected that you are using a GitHub private e-mail address to contribute to the repo.
Please turn off Keep my email addresses private setting in your account.
See LLVM Discourse for more information.

@banach-space (Contributor) left a comment

Just a couple of drive-by comments - addressing those will help reviewing. Will take a proper look later, thanks!

@@ -53,6 +53,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
Location loc, Value mask,
int origElements, int scale,
int intraDataOffset = 0) {
assert(intraDataOffset < scale && "intraDataOffset must be less than scale");
@banach-space (Contributor) commented:

This is not clear from the method name ...

What are origElements, scale and intraDataOffset?

@lialan (Member, Author) commented:

  • origElements is the number of elements of the sub-byte vector.
  • scale is the byte-emulated element type size divided by the original element type size. For example, if the original element type is i2, then the scale is sizeof(i8)/sizeof(i2) = 4.
  • intraDataOffset is the element offset into the emulated byte. For example, to extract the second slice of vector<3xi2> out of a vector<3x3xi2> (assuming the sub-byte elements are stored packed in memory), we need to load 2 bytes (the first and the second) and extract bits [6, 12) from them; the first 3 i2 elements are irrelevant, hence intraDataOffset == 3 in this case.
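For concreteness, plugging that example into the size computation used by getCompressedMaskOp (a worked instance of the formula from the diff above, with origElements = 3, scale = 4, intraDataOffset = 3):

  numElements = ceil((intraDataOffset + origElements) / scale)
              = ceil((3 + 3) / 4)
              = 2

i.e. two i8 bytes must be loaded to cover the three i2 elements.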

@banach-space (Contributor) commented:

OK, this was not clear to me at all :)

I was trying to understand all of this a bit better and am just thinking that this logic needs TLC. The comment for this method needs updating to capture the info that you shared above. I think that it would also be good to provide more descriptive argument names.

Now, I appreciate that it wasn't you who wrote this to begin with and updating this shouldn't be a blocker for this PR. Some help would be appreciated. Also, I want to help:

@banach-space (Contributor) commented:

Here's my attempt to improve the comments and input variable names:

Please let me know whether that makes sense to you, and any feedback is welcome.

Note, I've also created these two:

(again, your feedback would be appreciated). Last, but not least, this example seems off. In particular:

/// [Comment from Andrzej] 6 elements
///  %mask = [1, 1, 0, 0, 0, 0]
///
/// will first be padded with number of `intraDataOffset` zeros:
/// [Comment from Andrzej] 8 elements != 6 + 1
///   %mask = [0, 1, 1, 0, 0, 0, 0, 0]

Shouldn't the padded mask be: %mask = [0, 1, 1, 0, 0, 0, 0] (7 elements)?

Btw, thanks so much for working on this - your efforts are truly appreciated! Please don’t let my comments (and appetite to improve things overall) give you any other impression 😅.

@lialan (Member, Author) commented:

You are right! Here I just exposed some intermediate calculation details in the comment: in this case scale == 2, so making the padded mask a multiple of scale in the intermediate result is easier.
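Spelled out for that example (intraDataOffset = 1, scale = 2; the compressed row is an inferred intermediate following getCompressedMaskOp's logic, not a quote from the patch):

  %mask      = [1, 1, 0, 0, 0, 0]          // 6 elements
  padded     = [0, 1, 1, 0, 0, 0, 0, 0]    // shifted by 1, padded to numElements * scale = 8
  compressed = [1, 1, 0, 0]                // one i1 per emulated byte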

@lialan (Member, Author) commented:

Slightly updated the comment part. Can you take a look at it again?

@banach-space (Contributor) left a comment

Sorry for the delay, I'm just a bit confused by the current state of affairs in this area 😅

I really appreciate you working on this and don't want to block this, but we should also think about improving the maintainability of this code. Just to be on the constructive side of things (please review):

I will take another look at this later.

@banach-space (Contributor) left a comment

Mostly nits. I'm a bit distracted with the other PRs :(

@lialan force-pushed the lialan/dynamic_masked_load branch from be34311 to d6437e9 on November 11, 2024 20:45
@lialan lialan requested a review from banach-space November 11, 2024 20:46
@hanhanW (Contributor) left a comment

Thanks a lot to @banach-space, who provided so many useful review comments and kept the code/comments healthy! Also thanks to @lialan, who helped drive the effort!

The PR itself looks good to me, just a few comments.

IIRC, this adds support for the case where the source mask comes from vector.constant_mask; the case where the source mask comes from vector.create_mask is not supported yet. Am I correct? If so, could you add that information to the PR description?

@lialan lialan requested review from hanhanW and dcaballe November 12, 2024 01:15
@hanhanW (Contributor) left a comment

LGTM. Please wait for @banach-space before landing the PR.

@banach-space (Contributor) left a comment

LGTM, thanks for bearing with me!

@hanhanW hanhanW merged commit c3c3ccc into llvm:main Nov 12, 2024
8 checks passed
hanhanW pushed a commit that referenced this pull request Nov 12, 2024
In `staticallyExtractSubvector`, when the slice being extracted is the same size as the source vector, there is no need to emit `vector.extract_strided_slice`.

This fixes the lit test case `@vector_store_i4` in
`mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir`, where
converting from `vector<8xi4>` to `vector<4xi8>` does not need slice
extraction.
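
A minimal sketch of the full-width case this follow-up targets (the function and buffer names are hypothetical; the point is that the bitcast result already spans the whole store payload, so a full-width vector.extract_strided_slice would be a no-op):

func.func @store_i4_full_width(%src: vector<8xi4>, %alloc: memref<4xi8>) {
  %c0 = arith.constant 0 : index
  // 8 x i4 = 32 bits maps exactly onto 4 x i8, so the slice to extract
  // equals the source vector and no vector.extract_strided_slice is needed.
  %bytes = vector.bitcast %src : vector<8xi4> to vector<4xi8>
  vector.store %bytes, %alloc[%c0] : memref<4xi8>, vector<4xi8>
  return
}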

The issue was introduced in #113411 and #115070, CI failure link:
https://buildkite.com/llvm-project/github-pull-requests/builds/118845

This PR does not include a lit test case because it is a fix, and the
above-mentioned `@vector_store_i4` test already exercises the mechanism.

Signed-off-by: Alan Li <[email protected]>