[mlir][vector][nfc] Add tests + update docs for narrow-type emulation #115460


Merged

Conversation

banach-space
Contributor

@banach-space banach-space commented Nov 8, 2024

The documentation for narrow-type emulation was sparse, so I’ve expanded
it with additional clarifications (e.g., specifying that the example
discusses i4 -> i8 emulation).

I also noticed some inconsistencies in testing for narrow-type
emulation, with several cases covered only for "loading" and missing for
"storing." To address this, I’ve:

  • Added comments in the test file for easier reference,
  • Added the missing tests for vector.maskedstore.

Additionally, I’ve renamed tests for vector.masked{load|store} for
clarity:

  • @vector_cst_maskedload_i8 -> @vector_maskedload_i8_constant_mask.

This makes it easier to contrast with similar functions, such as
@vector_maskedload_i8.

Lastly, I’ve added a high-level comment in VectorEmulateNarrowType.cpp
to clarify the overall design and intent of the file.
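
For reference, here is a simplified sketch of the i4 -> i8 rewrite that the
updated comment and tests describe (names such as %alloc, %lidx and %new_mask
are placeholders rather than the exact IR produced by the pass):

  // Before: a masked store of 8 x i4 (sub-byte, not byte-addressable).
  vector.maskedstore %src[%c0], %mask, %val : memref<8xi4>, vector<8xi1>, vector<8xi4>

  // After (conceptually): load the covering i8 bytes, select per i4 lane, store back.
  %load   = vector.maskedload %alloc[%lidx], %new_mask, %passthru
              : memref<4xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
  %bcast  = vector.bitcast %load : vector<4xi8> to vector<8xi4>
  %select = arith.select %mask, %val, %bcast : vector<8xi1>, vector<8xi4>
  %packed = vector.bitcast %select : vector<8xi4> to vector<4xi8>
  vector.maskedstore %alloc[%lidx], %new_mask, %packed : memref<4xi8>, vector<4xi1>, vector<4xi8>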

@llvmbot
Member

llvmbot commented Nov 8, 2024

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes

The documentation for narrow-type emulation is a bit inaccurate. In
particular, we don't really support/generate masks like this:

%mask = [0, 1, 1, 1, 1, 1, 0, 0]

I updated the comment for ConvertVectorMaskedStore accordingly. I also
added a few clarifications (e.g. that the comment is discussing i4 -> i8
emulation).

Separately, I've noticed some inconsistency in testing for
narrow-type emulation. In particular, there are a few cases that are
tested for "loading" but missing for "storing". I've added

  • comments in the test file so that it's easy to see what's tested,
  • missing tests for vector.maskedstore.

Finally, I've added a top level comment in VectorEmulateNarrowType.cpp
so that the overall intent and design are clearer.


Full diff: https://github.com/llvm/llvm-project/pull/115460.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+28-13)
  • (modified) mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir (+126)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 58841f29698e0d..b29617c09ea4ec 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1,11 +1,19 @@
-//===- VectorEmulateNarrowType.cpp - Narrow type emulation ----*- C++
-//-*-===//
+//===- VectorEmulateNarrowType.cpp - Narrow type emulation ----------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to emulate
+// narrow types that are not supported by the target hardware, e.g. i4, using
+// wider types, e.g. i8.
+//
+/// Currently, only power-of-two integer types are supported. These are
+/// converted to wider integers that are either 8 bits wide or wider.
+//
+//===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -315,21 +323,28 @@ struct ConvertVectorMaskedStore final
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);
 
     // Load the whole data and use arith.select to handle the corner cases.
-    // E.g., given these input values:
+    // E.g., given these input i4 values:
+    //
+    //   %res = vector.maskedload %0[%c0, %c0], %mask, %val_to_store :
+    //
+    //   %mask = [1, 1, 1, 1, 1, 1, 1, 0]                     (8 * i1)
+    //   %0[%c0, %c0] = 
+    //      [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8]          (8 * i4)
+    //   %val_to_store =
+    //      [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0]          (8 * i4)
     //
-    //   %mask = [0, 1, 1, 1, 1, 1, 0, 0]
-    //   %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8]
-    //   %value_to_store = [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0]
+    // we'll have the following i4 output:
     //
-    // we'll have
+    //    expected output: [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x8]
     //
-    //    expected output: [0x1, 0xA, 0xB, 0xC, 0xD, 0xE, 0x7, 0x8]
+    // Emulating the above using i8 will give:
     //
-    //    %new_mask = [1, 1, 1, 0]
-    //    %maskedload = [0x12, 0x34, 0x56, 0x00]
-    //    %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x0, 0x0]
-    //    %select_using_shifted_mask = [0x1, 0xA, 0xB, 0xC, 0xD, 0xE, 0x0, 0x0]
-    //    %packed_data = [0x1A, 0xBC, 0xDE, 0x00]
+    //    %compressed_mask = [1, 1, 1, 1]                     (4 * i1)
+    //    %maskedload = [0x12, 0x34, 0x56, 0x78]              (4 * i8)
+    //    %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8] (8 * i4)
+    //    %select_using_shifted_mask =
+    //      [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x8]          (8 * i4)
+    //    %packed_data = [0x9A, 0xBC, 0xDE, 0xF8]             (4 * i8)
     //
     // Using the new mask to store %packed_data results in expected output.
     FailureOr<Operation *> newMask =
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index cba299b2a1d956..c98b4dd50a7028 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -1,6 +1,10 @@
 // RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
 // RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=32" --cse --split-input-file %s | FileCheck %s --check-prefix=CHECK32
 
+///----------------------------------------------------------------------------------------
+/// vector.load
+///----------------------------------------------------------------------------------------
+
 func.func @vector_load_i8(%arg1: index, %arg2: index) -> vector<4xi8> {
     %0 = memref.alloc() : memref<3x4xi8>
     %1 = vector.load %0[%arg1, %arg2] : memref<3x4xi8>, vector<4xi8>
@@ -82,6 +86,10 @@ func.func @vector_load_i4_dynamic(%arg0 : index, %arg1 : index, %arg2 : index, %
 
 // -----
 
+///----------------------------------------------------------------------------------------
+/// vector.transfer_read
+///----------------------------------------------------------------------------------------
+
 func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> {
     %c0 = arith.constant 0 : i4
     %0 = memref.alloc() : memref<3x8xi4>
@@ -111,6 +119,10 @@ func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> {
 
 // -----
 
+///----------------------------------------------------------------------------------------
+/// vector.maskedload
+///----------------------------------------------------------------------------------------
+
 func.func @vector_maskedload_i8(%arg1: index, %arg2: index, %arg3: index, %passthru: vector<4xi8>) -> vector<4xi8> {
     %0 = memref.alloc() : memref<3x4xi8>
     %mask = vector.create_mask %arg3 : vector<4xi1>
@@ -263,6 +275,10 @@ func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vecto
 
 // -----
 
+///----------------------------------------------------------------------------------------
+/// vector.extract -> vector.masked_load
+///----------------------------------------------------------------------------------------
+
 func.func @vector_extract_maskedload_i4(%arg1: index) -> vector<8x8x16xi4> {
     %0 = memref.alloc() : memref<8x8x16xi4>
     %c0 = arith.constant 0 : index
@@ -353,6 +369,10 @@ func.func @vector_extract_cst_maskedload_i4() -> vector<8x8x16xi4> {
 
 // -----
 
+///----------------------------------------------------------------------------------------
+/// vector.store
+///----------------------------------------------------------------------------------------
+
 func.func @vector_store_i8(%arg0: vector<8xi8>, %arg1: index, %arg2: index) {
     %0 = memref.alloc() : memref<4x8xi8>
     vector.store %arg0, %0[%arg1, %arg2] :memref<4x8xi8>, vector<8xi8>
@@ -431,6 +451,10 @@ func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: ind
 
 // -----
 
+///----------------------------------------------------------------------------------------
+/// vector.maskedstore
+///----------------------------------------------------------------------------------------
+
 func.func @vector_maskedstore_i8(%arg0: index, %arg1: index, %arg2: index, %value: vector<8xi8>) {
   %0 = memref.alloc() : memref<3x8xi8>
   %mask = vector.create_mask %arg2 : vector<8xi1>
@@ -469,6 +493,61 @@ func.func @vector_maskedstore_i8(%arg0: index, %arg1: index, %arg2: index, %valu
 
 // -----
 
+func.func @vector_maskedstore_i4(
+  %idx1: index,
+  %idx2: index,
+  %num_elements_to_store: index,
+  %value: vector<8xi4>) {
+
+    %0 = memref.alloc() : memref<3x8xi4>
+    %cst = arith.constant dense<0> : vector<3x8xi4>
+    %mask = vector.create_mask %num_elements_to_store : vector<8xi1>
+    vector.maskedstore %0[%idx1, %idx2], %mask, %value :
+      memref<3x8xi4>, vector<8xi1>, vector<8xi4>
+    return
+}
+// CHECK: #[[$ATTR_10:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+// CHECK: #[[$ATTR_11:.+]] = affine_map<()[s0] -> ((s0 + 1) floordiv 2)>
+
+// CHECK-LABEL:   func.func @vector_maskedstore_i4(
+// CHECK-SAME:      %[[IDX_1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[IDX_2:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[NUM_EL_TO_STORE:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
+// CHECK:           %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+// CHECK:           %[[ORIG_MASK:.+]] = vector.create_mask %[[NUM_EL_TO_STORE]] : vector<8xi1>
+// CHECK:           %[[LIDX:.+]] = affine.apply #[[$ATTR_10]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
+// CHECK:           %[[MASK_IDX:.+]] = affine.apply #[[$ATTR_11]](){{\[}}%[[NUM_EL_TO_STORE]]]
+// CHECK:           %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<4xi1>
+// CHECK:           %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<4xi8>
+// CHECK:           %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+// CHECK:           %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
+// CHECK:           %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
+// CHECK:           %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<4xi8>
+// CHECK:           vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<12xi8>, vector<4xi1>, vector<4xi8>
+
+// CHECK32: #[[$ATTR_17:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+// CHECK32: #[[$ATTR_18:.+]] = affine_map<()[s0] -> ((s0 + 7) floordiv 8)>
+
+// CHECK32-LABEL:   func.func @vector_maskedstore_i4(
+// CHECK32-SAME:      %[[IDX_1:[a-zA-Z0-9]+]]: index,
+// CHECK32-SAME:      %[[IDX_2:[a-zA-Z0-9]+]]: index,
+// CHECK32-SAME:      %[[NUM_EL_TO_STORE:[a-zA-Z0-9]+]]: index,
+// CHECK32-SAME:      %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
+// CHECK32:           %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+// CHECK32:           %[[ORIG_MASK:.+]] = vector.create_mask %[[NUM_EL_TO_STORE]] : vector<8xi1>
+// CHECK32:           %[[LIDX:.+]] = affine.apply #[[$ATTR_17]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
+// CHECK32:           %[[MASK_IDX:.+]] = affine.apply #[[$ATTR_18]](){{\[}}%[[NUM_EL_TO_STORE]]]
+// CHECK32:           %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<1xi1>
+// CHECK32:           %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<1xi32>
+// CHECK32:           %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
+// CHECK32:           %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4>
+// CHECK32:           %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
+// CHECK32:           %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32>
+// CHECK32:           vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<3xi32>, vector<1xi1>, vector<1xi32>
+
+// -----
+
 func.func @vector_cst_maskedstore_i8(%arg0: index, %arg1: index, %value: vector<8xi8>) {
   %0 = memref.alloc() : memref<3x8xi8>
   %mask = vector.constant_mask [4] : vector<8xi1>
@@ -500,3 +579,50 @@ func.func @vector_cst_maskedstore_i8(%arg0: index, %arg1: index, %value: vector<
 // CHECK32:        %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL]], %[[BITCAST]] : vector<8xi1>, vector<8xi8>
 // CHECK32:        %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi8> to vector<2xi32>
 // CHECK32:        vector.maskedstore %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]]
+
+// -----
+
+func.func @vector_cst_maskedstore_i4(
+  %idx_1: index,
+  %idx_2: index,
+  %val_to_store: vector<8xi4>) {
+
+    %0 = memref.alloc() : memref<3x8xi4>
+    %cst = arith.constant dense<0> : vector<3x8xi4>
+    %mask = vector.constant_mask [4] : vector<8xi1>
+    vector.maskedstore %0[%idx_1, %idx_2], %mask, %val_to_store :
+      memref<3x8xi4>, vector<8xi1>, vector<8xi4>
+    return
+}
+
+// CHECK: #[[$ATTR_12:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+// CHECK-LABEL:   func.func @vector_cst_maskedstore_i4(
+// CHECK-SAME:      %[[IDX_1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[IDX_2:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
+// CHECK:           %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+// CHECK:           %[[ORIG_MASK:.+]] = vector.constant_mask [4] : vector<8xi1>
+// CHECK:           %[[LIDX:.+]] = affine.apply #[[$ATTR_12]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
+// CHECK:           %[[NEW_MASK:.+]] = vector.constant_mask [2] : vector<4xi1>
+// CHECK:           %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<4xi8>
+// CHECK:           %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+// CHECK:           %[[VAL_9:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
+// CHECK:           %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[VAL_9]] : vector<8xi1>, vector<8xi4>
+// CHECK:           %[[BITCAST:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<4xi8>
+// CHECK:           vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[BITCAST]] : memref<12xi8>, vector<4xi1>, vector<4xi8>
+
+// CHECK32: #[[$ATTR_20:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+// CHECK32-LABEL:   func.func @vector_cst_maskedstore_i4(
+// CHECK32-SAME:      %[[IDX_1:[a-zA-Z0-9]+]]: index,
+// CHECK32-SAME:      %[[IDX_2:[a-zA-Z0-9]+]]: index,
+// CHECK32-SAME:      %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
+// CHECK32:           %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+// CHECK32:           %[[ORIG_MASK:.+]] = vector.constant_mask [4] : vector<8xi1>
+// CHECK32:           %[[LIDX:.+]] = affine.apply #[[$ATTR_20]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
+// CHECK32:           %[[NEW_MASK:.+]] = vector.constant_mask [1] : vector<1xi1>
+// CHECK32:           %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<1xi32>
+// CHECK32:           %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
+// CHECK32:           %[[VAL_9:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4>
+// CHECK32:           %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_2]], %[[VAL_9]] : vector<8xi1>, vector<8xi4>
+// CHECK32:           %[[BITCAST:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32>
+// CHECK32:           vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[BITCAST]] : memref<3xi32>, vector<1xi1>, vector<1xi32>

@llvmbot
Member

llvmbot commented Nov 8, 2024

@llvm/pr-subscribers-mlir-vector


github-actions bot commented Nov 8, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

The documentation for narrow-type emulation is a bit inaccurate. In
particular, we don't really support/generate masks like this:

  %mask = [0, 1, 1, 1, 1, 1, 0, 0]

I updated the comment for `ConvertVectorMaskedStore` accordingly. I also
added a few clarifications (e.g. that the comment is discussing i4 -> i8
emulation).

Separately, I've noticed some inconsistency in testing for
narrow-type emulation. In particular, there are a few cases that are
tested for "loading" but missing for "storing". I've added
  * comments in the test file so that it's easy to see what's tested,
  * missing tests for `vector.maskedstore`.

Finally, I've added a top level comment in VectorEmulateNarrowType.cpp
so that the overall intent and design are clearer.
…ulation

* Fix failing test
* Tweak/fix the comment
* Rename: @vector_cst_maskedload_i8 -> @vector_cst_maskedload_i8_constant_mask (same for other similar tests)
@dcaballe
Contributor

dcaballe commented Nov 8, 2024

%mask = [0, 1, 1, 1, 1, 1, 0, 0]

The preceding zeros are common for the unaligned cases introduced here: #113411.
@lialan may want to clarify.

@banach-space
Contributor Author

banach-space commented Nov 8, 2024

%mask = [0, 1, 1, 1, 1, 1, 0, 0]

The preceding zeros are common for the unaligned cases introduced here: #113411.

Hm, I was trying to come up with an example but am a bit stumped 😅. For instance:

    %1 = vector.maskedload %0[%c2], %mask, %passthru : memref<10xi2>, vector<8xi1>, vector<8xi2> into vector<8xi2>

Even though the i2 load likely isn’t byte-aligned, wouldn’t the emulated load begin with the byte containing the first value to load? If so, for this initial mask for i2 (loading at %c2):

%mask_for_i2 = [1, 1, 1, 1, 1, 1, 1, 0]

we should generate the following mask for i8 (loading at %c0):

%mask_for_i8 = [1, 1]

In other words, I think we almost always start with [1, 1, ...] (at least one leading "1") and compress to [1, ...] (at least one leading "1"). Let me know if I’m missing something obvious 🤔
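
As a side note: for the aligned i4 -> i8 path, the compressed mask in this
patch's tests is built roughly like this (a sketch paraphrasing the CHECK
lines; %n, %new_len and %new_mask are placeholder names):

  // Two i4 elements share one i8 byte, so a dynamic mask of length %n is
  // compressed to length ceil(%n / 2).
  %new_len  = affine.apply affine_map<()[s0] -> ((s0 + 1) floordiv 2)>()[%n]
  %new_mask = vector.create_mask %new_len : vector<4xi1>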


EDIT 9/11/24

I just discovered another point relevant to this discussion—2D vectors (where I had been assuming only 1D vectors):

func.func @vector_maskedload_2d_i8_negative(
  %idx1: index,
  %idx2: index,
  %num_elems: index,
  %passthru: vector<2x4xi8>) -> vector<2x4xi8> {

    %0 = memref.alloc() : memref<3x4xi8>
    %mask = vector.create_mask %num_elems, %num_elems : vector<2x4xi1>
    %1 = vector.maskedload %0[%idx1, %idx2], %mask, %passthru :
      memref<3x4xi8>, vector<2x4xi1>, vector<2x4xi8> into vector<2x4xi8>
    return %1 : vector<2x4xi8>

}

TBH, it's not entirely clear what should happen here 😅 In fact, that's not supported at all:

within split at file.mlir:32 offset :11:10: error: 'vector.bitcast' op failed to verify that all of {source, result} have same rank
    %1 = vector.maskedload %0[%idx1, %idx2], %mask, %passthru :
         ^
within split at file.mlir:32 offset :11:10: note: see current operation: %11 = "vector.bitcast"(%arg3) : (vector<2x4xi8>) -> vector<2xi32>

Turns out there isn’t a single test for the 2D case, and indeed, all patterns under

fail. I’ve prepared a patch to document that:

For simplicity’s sake, I’m inclined to enforce 1D support only for now.

// narrow types that are not supported by the target hardware, e.g. i4, using
// wider types, e.g. i8.
//
/// Currently, only power-of-two integer types are supported. These are
Contributor

Note that there is actually nothing fundamental about power of 2 here. It is an implementation detail and just NYI. It might be better to make that clear.

Contributor

+1, perhaps we can explicitly say in the description that non-power-of-two types are not implemented yet. That would be clearer for newcomers to the narrow-type emulation.

Contributor Author

I will add a TODO at the top - clear indication for anyone brave enough to tackle this :)

@lialan
Member

lialan commented Nov 11, 2024

%mask = [0, 1, 1, 1, 1, 1, 0, 0]

The preceding zeros are common for the unaligned cases introduced here: #113411. @lialan may want to clarify.

@dcaballe @banach-space the preceding zeros will be used when a memref's innermost dim is not strictly byte-aligned. This could happen not only in the input but also after some transformations such as tiling. So it is definitely needed (but best to avoid, as it sometimes emits poor code).
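
For illustration, here is a hypothetical case (not taken from this PR) where that happens: the innermost dim below holds 3 x i2 = 6 bits, so a store to row 1 starts in the middle of a byte and the emulated store needs a mask with leading zeros to protect the bits of row 0 that share that byte:

func.func @unaligned_row_store(%val: vector<3xi2>, %mask: vector<3xi1>) {
  %0 = memref.alloc() : memref<4x3xi2>
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  // Element (1, 0) lives at bit offset 6, i.e. mid-way through byte 0.
  vector.maskedstore %0[%c1, %c0], %mask, %val : memref<4x3xi2>, vector<3xi1>, vector<3xi2>
  return
}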

@banach-space
Contributor Author

@dcaballe @banach-space the preceding zeros will be used when a memref's innermost dim is not strictly byte aligned.

Thank you for explaining; that makes sense! I’ll restore the original logic in the comment. This only occurs in cases with sub-byte types, correct? It’d be helpful to call that out explicitly.

Also, considering the definitions of vector.constant_mask and vector.create_mask, how would we construct the corresponding mask? I assume it would be created via arith.constant, something like this:

func.func @vector_maskedload_i4_constant_mask(%passthru: vector<8xi4>) -> vector<8xi4> {
  %0 = memref.alloc() : memref<3x8xi4>
  %cst = arith.constant dense<0> : vector<8xi4>
  %mask = arith.constant dense<[false, true, true, true, true, false, false, false]> : vector<8xi1>
  %c0 = arith.constant 0 : index
  %1 = vector.maskedload %0[%c0, %c0], %mask, %passthru :
    memref<3x8xi4>, vector<8xi1>, vector<8xi4> into vector<8xi4>
  return %1 : vector<8xi4>
}

This example demonstrates the scenario described in the comment and could also help clarify expected behavior. Unfortunately, it doesn’t work as expected:

This suggests that while the sub-byte alignment case is theoretically possible, it’s likely that no one is actively relying on it, given that it’s currently broken.

Since everyone is stretched thin, I’d propose proactively disabling and documenting these broken paths to prevent confusion or missteps down the road. Hence e.g.:

Just to clarify, I merely want to improve the overall health of this area. And, a bit selfishly, to make reviewing easier :) Like I said elsewhere, your contributions in this area @lialan are much appreciated!

@banach-space
Contributor Author

%mask = [0, 1, 1, 1, 1, 1, 0, 0]

The preceding zeros are common for the unaligned cases introduced here: #113411.

👍🏻 As pointed out in my previous reply, I wasn't able to test that case. But I see what you and @lialan mean and I restored that part of the comment. Still, using an example that doesn't work to document behaviour is not ideal 😅 (perhaps there's a version that works?)

Contributor

@hanhanW hanhanW left a comment

Thanks for improving the docs and tests, really appreciate it!

// CHECK-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
// CHECK: %[[ORIG_MASK:.+]] = vector.create_mask %[[NUM_EL_TO_STORE]] : vector<8xi1>
// CHECK: %[[LIDX:.+]] = affine.apply #[[$ATTR_10]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
Contributor

nit: we don't need to escape [ because we do not capture % in variables. It is one of the pros of the trick, and it looks cleaner. Can you update it and fix other checks in the new lit tests?

Suggested change
// CHECK: %[[LIDX:.+]] = affine.apply #[[$ATTR_10]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
// CHECK: %[[LIDX:.+]] = affine.apply #[[$ATTR_10]]()[%[[IDX_1]], %[[IDX_2]]]

Contributor Author

Oh, very nice, thanks for pointing this out 🙏🏻

@hanhanW
Contributor

hanhanW commented Nov 11, 2024

%mask = [0, 1, 1, 1, 1, 1, 0, 0]

The preceding zeros are common for the unaligned cases introduced here: #113411.

Just for future reference, the comment about it is in #113411 (comment). I marked the conversation unresolved, so people can see it easily in the PR.

@banach-space
Contributor Author

Thank you all for the discussion! I will land this later today if there are no new comments.

%mask = [0, 1, 1, 1, 1, 1, 0, 0]

The preceding zeros are common for the unaligned cases introduced here: #113411.

just for future reference, the comment about it is in #113411 (comment) . I marked the conversation unresolved, so people can see it easily in the PR.

Thanks for digging this out, and apologies for not commenting when this was initially discussed. For now, I’ve created #115742 and referenced it in the comment. Hopefully, a volunteer can look into fixing it 😅 - if not, we may need to revisit.

@banach-space banach-space merged commit 6fe7ad8 into llvm:main Nov 12, 2024
8 checks passed
@banach-space banach-space deleted the andrzej/emulate_narrow_type_update branch November 12, 2024 15:49