
[MLIR] Fix VectorEmulateNarrowType constant op mask bug #116064


Merged 1 commit into llvm:main on Nov 15, 2024

Conversation

@lialan (Member) commented Nov 13, 2024

This commit adds support for handling mask constants generated by the arith.constant op in the VectorEmulateNarrowType pattern. Previously, this pattern would not match due to the lack of mask constant handling in getCompressedMaskOp.

The changes include:

  1. Updating getCompressedMaskOp to recognize and handle arith.constant ops as mask value sources.

  2. Handling cases where the mask is not aligned with the emulated load width. The compressed mask is adjusted to account for the offset.

Limitations:

  • The arith.constant op can only have 1-dimensional constant values.

Resolves: #115742
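
To make the compression step concrete, here is a minimal standalone C++ sketch (not the actual patch; the free-standing function and its names are hypothetical) of how a 1-D boolean mask can be padded by the front offset and then OR-reduced in groups of `scale` elements:

#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper mirroring the padding + OR-reduction described above.
// `scale` is the number of narrow elements per emulated byte (e.g. 4 for i2
// emulated in i8); `frontOffset` is the intra-byte offset of the first mask
// element.
std::vector<bool> compressMask(const std::vector<bool> &mask, size_t scale,
                               size_t frontOffset) {
  assert(scale > 0 && frontOffset < scale);
  // Pad with `false` in front so the mask starts at an emulated-byte boundary.
  std::vector<bool> padded(frontOffset, false);
  padded.insert(padded.end(), mask.begin(), mask.end());
  // Pad with `false` in the back until the length is a multiple of `scale`.
  while (padded.size() % scale != 0)
    padded.push_back(false);
  // An emulated byte must be accessed if any of its `scale` elements is set.
  std::vector<bool> compressed;
  for (size_t i = 0; i < padded.size(); i += scale) {
    bool anySet = false;
    for (size_t j = 0; j < scale; ++j)
      anySet |= padded[i + j];
    compressed.push_back(anySet);
  }
  return compressed;
}

For example, with scale = 4 and frontOffset = 1, the mask [0, 1, 0, 1, 0, 0] is padded to [0, 0, 1, 0, 1, 0, 0, 0] and compresses to [1, 1]: both emulated bytes must be accessed.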

@llvmbot (Member) commented Nov 13, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: lialan (lialan)


Full diff: https://github.com/llvm/llvm-project/pull/116064.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+44-2)
  • (modified) mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir (+171)
  • (modified) mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir (+24)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index eb4ce24548e603..f2e9ae18d3371c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -70,7 +70,9 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
   Operation *maskOp = mask.getDefiningOp();
   SmallVector<vector::ExtractOp, 2> extractOps;
   // Finding the mask creation operation.
-  while (maskOp && !isa<vector::CreateMaskOp, vector::ConstantMaskOp>(maskOp)) {
+  while (maskOp &&
+         !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
+             maskOp)) {
     if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
       maskOp = extractOp.getVector().getDefiningOp();
       extractOps.push_back(extractOp);
@@ -78,7 +80,8 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
   }
   auto createMaskOp = dyn_cast_or_null<vector::CreateMaskOp>(maskOp);
   auto constantMaskOp = dyn_cast_or_null<vector::ConstantMaskOp>(maskOp);
-  if (!createMaskOp && !constantMaskOp)
+  auto constantOp = dyn_cast_or_null<arith::ConstantOp>(maskOp);
+  if (!createMaskOp && !constantMaskOp && !constantOp)
     return failure();
 
   // Computing the "compressed" mask. All the emulation logic (i.e. computing
@@ -129,6 +132,45 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
       auto denseAttr = DenseElementsAttr::get(newMaskType, newMaskValues);
       newMask = rewriter.create<arith::ConstantOp>(loc, newMaskType, denseAttr);
     }
+  } else if (constantOp) {
+    assert(shape.size() == 1 && "expected 1-D mask");
+    // Rearrange the original mask values to cover the whole potential loading
+    // region. For example, in the case of using byte-size for emulation, given
+    // the following mask:
+    //
+    //   %mask = vector.constant_mask [0, 1, 0, 1, 0, 0] : vector<6xi1>
+    //
+    // with a front offset of 1, the mask will be padded with zeros in the
+    // front and back so that its length is a multiple of `scale` (and the
+    // total coverage size is a multiple of bytes):
+    //   %new_mask = vector.constant_mask [0, 0, 1, 0, 1, 0, 0, 0] :
+    //   vector<8xi1>
+    //
+    // The %new_mask is now aligned with the effective loading area and can now
+    // be compressed.
+    SmallVector<bool> maskValues(intraDataOffset, false);
+    if (auto denseAttr =
+            mlir::dyn_cast<DenseIntElementsAttr>(constantOp.getValue())) {
+      for (auto value : denseAttr.getValues<bool>()) {
+        maskValues.push_back(value);
+      }
+      while (maskValues.size() < numElements * scale) {
+        maskValues.push_back(false);
+      }
+    } else {
+      return failure();
+    }
+    // Compressing by combining every `scale` elements:
+    SmallVector<bool> compressedMaskValues;
+    for (size_t i = 0; i < maskValues.size(); i += scale) {
+      bool combinedValue = false;
+      for (int j = 0; j < scale; ++j) {
+        combinedValue |= maskValues[i + j];
+      }
+      compressedMaskValues.push_back(combinedValue);
+    }
+    newMask = rewriter.create<arith::ConstantOp>(
+        loc, DenseElementsAttr::get(newMaskType, compressedMaskValues));
   }
 
   while (!extractOps.empty()) {
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 7ed75ff7f1579c..359162d76219f4 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -249,3 +249,174 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
 // CHECK: %[[IN8:.+]] = vector.insert %[[EX8]], %[[IN7]] [1] : i2 into vector<3xi2>
 // CHECK: %[[EX9:.+]] = vector.extract %[[SELECT]][%[[INCIDX2]]] : i2 from vector<8xi2>
 // CHECK: %[[IN9:.+]] = vector.insert %[[EX9]], %[[IN8]] [2] : i2 into vector<3xi2>
+
+// -----
+
+func.func @vector_store_i2_const(%arg0: vector<3xi2>) {
+    %0 = memref.alloc() : memref<3x3xi2>
+    %c0 = arith.constant 0 : index
+    %c2 = arith.constant 2 : index
+    vector.store %arg0, %0[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
+    return
+}
+
+// In this example, emit 2 atomic stores, with the first storing 2 elements and the second storing 1 element.
+// CHECK: func @vector_store_i2_const(
+// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+
+// atomic store of the first byte
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, true, true]> : vector<4xi1>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C1]]] : memref<3xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
+
+// atomic store of the second byte
+// CHECK: %[[ADDI:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]], %[[CST0]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW2:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[ADDI]]] : memref<3xi8> {
+// CHECK: %[[ARG2:.+]]: i8):
+// CHECK: %[[FROM_ELEM2:.+]] = vector.from_elements %[[ARG2]] : vector<1xi8>
+// CHECK: %[[BITCAST3:.+]] = vector.bitcast %[[FROM_ELEM2]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[BITCAST3]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST4:.+]] = vector.bitcast %[[SELECT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST4]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT3]] : i8
+
+// -----
+
+func.func @vector_store_i8_2(%arg0: vector<7xi2>) {
+    %0 = memref.alloc() : memref<3x7xi2>
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    vector.store %arg0, %0[%c1, %c0] : memref<3x7xi2>, vector<7xi2>
+    return
+}
+
+// In this example, emit 2 atomic stores and 1 non-atomic store.
+
+// CHECK: func @vector_store_i8_2(
+// CHECK-SAME: %[[ARG0:.+]]: vector<7xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, false, true]> : vector<4xi1>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+
+// first atomic store
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [1], strides = [1]} : vector<7xi2> to vector<1xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [3], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C1]]] : memref<6xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
+
+// non atomic store part
+// CHECK: %[[ADDR:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]} : vector<7xi2> to vector<4xi2>
+// CHECK: %[[BITCAST3:.+]] = vector.bitcast %[[EXTRACT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: vector.store %[[BITCAST3]], %[[ALLOC]][%[[ADDR]]] : memref<6xi8>, vector<1xi8>
+
+// second atomic store
+// CHECK: %[[ADDR2:.+]] = arith.addi %[[ADDR]], %[[C1]] : index
+// CHECK: %[[EXTRACT3:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [5], sizes = [2], strides = [1]} : vector<7xi2> to vector<2xi2>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT3]], %[[CST0]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW2:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[ADDR2]]] : memref<6xi8> {
+// CHECK: %[[ARG2:.+]]: i8):
+// CHECK: %[[FROM_ELEM2:.+]] = vector.from_elements %[[ARG2]] : vector<1xi8>
+// CHECK: %[[BITCAST4:.+]] = vector.bitcast %[[FROM_ELEM2]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[BITCAST4]] :
+// CHECK-SAME: vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST5:.+]] = vector.bitcast %[[SELECT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT4:.+]] = vector.extract %[[BITCAST5]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT4]] : i8    
+
+// -----
+
+func.func @vector_store_i2_single_atomic(%arg0: vector<1xi2>) {
+    %0 = memref.alloc() : memref<4x1xi2>
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    vector.store %arg0, %0[%c1, %c0] : memref<4x1xi2>, vector<1xi2>
+    return
+}
+
+// In this example, emit only 1 atomic store.
+// CHECK: func @vector_store_i2_single_atomic(
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, true, false, false]> : vector<4xi1>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[ARG0]], %[[CST0]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<1xi2> into vector<4xi2>
+
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C0]]] : memref<1xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
+
+// -----
+
+func.func @vector_maskedload_i4_constant_mask_unaligned(%passthru: vector<5xi2>) -> vector<5xi2> {
+  %0 = memref.alloc() : memref<3x5xi2>
+  %mask = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %1 = vector.maskedload %0[%c1, %c0], %mask, %passthru :
+    memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
+  return %1 : vector<5xi2>
+}
+
+// CHECK: func @vector_maskedload_i4_constant_mask_unaligned(
+// CHECK-SAME: %[[PTH:.+]]: vector<5xi2>) -> vector<5xi2>
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<4xi8>
+// CHECK: %[[MASK:.+]] = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>
+
+// CHECK: %[[CST0:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[CST1:.+]] = arith.constant dense<0> : vector<8xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[PTH]], %[[CST1]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi2> into vector<8xi2>
+
+// Emulated masked load from alloc:
+// CHECK: %[[BCAST:.+]] = vector.bitcast %[[INSERT]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[MASKLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C1]]], %[[CST0]], %[[BCAST]]
+// CHECK: %[[BCAST2:.+]] = vector.bitcast %[[MASKLOAD]] : vector<2xi8> to vector<8xi2>
+
+// Select from emulated loaded vector and passthru vector:
+// TODO: fold this part if possible.
+// CHECK: %[[CST2:.+]] = arith.constant dense<false> : vector<8xi1>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[MASK]], %[[CST2]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi1> into vector<8xi1>
+// CHECK: %[[SELECT:.+]] = arith.select %[[INSERT2]], %[[BCAST2]], %[[INSERT]] : vector<8xi1>, vector<8xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SELECT]]
+// CHECK-SAME: {offsets = [1], sizes = [5], strides = [1]} : vector<8xi2> to vector<5xi2>
+// CHECK: return %[[EXTRACT]] : vector<5xi2>
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index 034bd47f6163e6..19edc9ddcaf2b4 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -624,3 +624,27 @@ func.func @vector_maskedstore_i4_constant_mask(
 // CHECK32:           %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
 // CHECK32:           %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32>
 // CHECK32:           vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<3xi32>, vector<1xi1>, vector<1xi32>
+
+// -----
+
+func.func @vector_maskedload_i4_arith_constant(%passthru: vector<8xi4>) -> vector<8xi4> {
+  %0 = memref.alloc() : memref<3x8xi4>
+  %cst = arith.constant dense<0> : vector<8xi4>
+  %mask = arith.constant dense<[false, true, true, true, true, false, false, false]> : vector<8xi1>
+  %c0 = arith.constant 0 : index
+  %1 = vector.maskedload %0[%c0, %c0], %mask, %passthru :
+    memref<3x8xi4>, vector<8xi1>, vector<8xi4> into vector<8xi4>
+  return %1 : vector<8xi4>
+}
+
+// CHECK: func @vector_maskedload_i4_arith_constant(
+// CHECK-SAME:   %[[PASSTHRU:[a-zA-Z0-9]+]]: vector<8xi4>) -> vector<8xi4> {
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+// CHECK: %[[MASK:.+]] = arith.constant dense<[false, true, true, true, true, false, false, false]> : vector<8xi1>
+// CHECK: %[[CST:.+]] = arith.constant dense<[true, true, true, false]> : vector<4xi1>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[PASSTHRU]] : vector<8xi4> to vector<4xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C0]]], %[[CST]], %[[BITCAST]] : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
+// CHECK: %[[SELECT:.+]] = arith.select %[[MASK]], %[[BITCAST2]], %[[PASSTHRU]] : vector<8xi1>, vector<8xi4>
+// CHECK: return %[[SELECT]] : vector<8xi4>

@lialan force-pushed the lialan/constant_op branch 2 times, most recently from d3a5396 to 5be2464 on November 13, 2024
@hanhanW requested a review from @banach-space on November 13, 2024
@hanhanW (Contributor) reviewed:

Overall looks good to me, just some nits about style/comments/tests.

@banach-space (Contributor) reviewed:

Thanks for the fix!

I think that the logic in getCompressedMaskOp could benefit from a high level refactor. In particular, why not replace:

if (createMaskOp) {
} else if (constantMaskOp) {
} else if (constantOp) {
}

with (pseudo syntax):

TypeSwitch<Operation *>(maskOp)
    .Case<CreateMaskOp>([&]() { ... })
    .Case<ConstantMaskOp>([&]() { ... })
    .Case<ConstantOp>([&]() { ... });

and then every case should be implemented in a dedicated hook.
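
For reference, a rough sketch of what this could look like with llvm::TypeSwitch (assuming the llvm/ADT/TypeSwitch.h utility; the compress* hooks are hypothetical names, not part of this patch):

// Dispatch on the mask-defining op; each case delegates to a dedicated hook
// that builds the compressed mask for that mask source.
FailureOr<Operation *> newMask =
    llvm::TypeSwitch<Operation *, FailureOr<Operation *>>(maskOp)
        .Case([&](vector::CreateMaskOp op) {
          return compressCreateMask(rewriter, loc, op);
        })
        .Case([&](vector::ConstantMaskOp op) {
          return compressConstantMask(rewriter, loc, op);
        })
        .Case([&](arith::ConstantOp op) {
          return compressArithConstant(rewriter, loc, op);
        })
        .Default([](Operation *) { return failure(); });

This keeps the dispatch in one place and moves each mask source's compression logic into its own small function.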

More comments inline.

@lialan force-pushed the lialan/constant_op branch from 5be2464 to 98f8b9e on November 14, 2024
@dcaballe (Contributor) reviewed:

Doing a first pass for now. Thanks!

@banach-space (Contributor) reviewed:

Thanks for the updates!

I only see tests for maskedload - how about maskedstore?

github-actions bot commented Nov 14, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@lialan force-pushed the lialan/constant_op branch from dbb73f6 to 7ed57ef on November 14, 2024
@lialan (Member, Author) commented Nov 14, 2024

@banach-space Thanks for the review! I've updated the patch according to your suggestions. I also added a masked store test case to illustrate the effect. Given that maskedstore does not support unaligned emulation yet (which is on my todo list), I only added one such test (the aligned case).

I think what really helps is using more meaningful names in the tests; it makes them quite self-explanatory!

@lialan requested a review from @banach-space on November 14, 2024
@hanhanW (Contributor) reviewed:

LGTM, just a couple of comments about the comments. I'll help land the PR once we get an approval from @banach-space.

@lialan force-pushed the lialan/constant_op branch 2 times, most recently from 9344055 to 574752d on November 15, 2024
@banach-space (Contributor) commented:

> I think what really helps is using more meaningful names in the tests; it makes them quite self-explanatory!

Thank you for this comment and for updating the tests 🙏🏻. This makes reviewing and maintenance so much easier - our future selves will be very grateful! 😊

I’ve left a few small suggestions, but otherwise, this looks good to me (LGTM).


Given the growing complexity of this logic, I have a couple of high-level observations:

  1. Should "vector-emulate-narrow-type.mlir" be renamed to "vector-emulate-narrow-type-aligned.mlir"? Should the cases in "vector-emulate-narrow-type-unaligned.mlir" essentially be variations of the cases in "vector-emulate-narrow-type-aligned.mlir"?

  2. All the existing cases in "vector-emulate-narrow-type-unaligned.mlir" test i2, while the new cases test i4. In a PR like this, where the key "novelty" is support for arith.constant (i.e. a new Op), I’d normally expect the data types to remain consistent with the other tests. This way, it’s easier to isolate and understand the differences introduced by the new changes.

I’m not suggesting any additional large changes or refactors in this PR - @lialan has already done a great job! But I think the patterns/trends highlighted above emerge naturally and could serve as a helpful guiding principle for future PRs. If others agree, of course. 😊

@banach-space (Contributor) reviewed:

LGTM % some final nits

The merged commit message:

This commit adds support for handling mask constants generated by the
`arith.constant` op in the `VectorEmulateNarrowType` pattern. Previously, this
pattern would not match due to the lack of mask constant handling in
`getCompressedMaskOp`.

The changes include:

1. Updating `getCompressedMaskOp` to recognize and handle `arith.constant` ops as
   mask value sources.

2. Handling cases where the mask is not aligned with the emulated load width.
   The compressed mask is adjusted to account for the offset.

Limitations:
- The arith.constant op can only have 1-dimensional constant values.

Resolves: llvm#115742

Signed-off-by: Alan Li <[email protected]>
@lialan force-pushed the lialan/constant_op branch from 574752d to 309abfe on November 15, 2024
@hanhanW merged commit ef92aba into llvm:main on Nov 15, 2024 (8 checks passed)
@lialan deleted the lialan/constant_op branch on November 15, 2024

Post-merge review comment on the test `@vector_maskedload_i4_constant_mask_unaligned`:

@hanhanW (Contributor): I missed it in the previous review. Could you remind me why it is labeled i4 but the test is loading i2 types?

@lialan (Member, Author): @hanhanW copy/paste mismatch, I should have double-checked it. I will update it later in a batch refactoring.

Merging this pull request closed issue #115742: "Bug in the VectorEmulateNarrowType logic".

5 participants