[mlir][vector][nfc] Add tests + update docs for narrow-type emulation #115460

Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes

The documentation for narrow-type emulation is a bit inaccurate. In particular, we don't really support/generate masks like this:

%mask = [0, 1, 1, 1, 1, 1, 0, 0]

I updated the comment for ConvertVectorMaskedStore accordingly and added a few clarifications (e.g. that the comment is discussing i4 -> i8 emulation).

Separately, I've noticed an inconsistency in testing for narrow-type emulation. In particular, there are a few cases that are tested for "loading" but missing for "storing". I've added:
* comments in the test file so that it's easy to see what's tested,
* missing tests for vector.maskedstore.

Finally, I've added a top-level comment in VectorEmulateNarrowType.cpp so that the overall intent and design are clearer.

Full diff: https://github.com/llvm/llvm-project/pull/115460.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 58841f29698e0d..b29617c09ea4ec 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1,11 +1,19 @@
-//===- VectorEmulateNarrowType.cpp - Narrow type emulation ----*- C++
-//-*-===//
+//===- VectorEmulateNarrowType.cpp - Narrow type emulation ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to emulate
+// narrow types that are not supported by the target hardware, e.g. i4, using
+// wider types, e.g. i8.
+//
+/// Currently, only power-of-two integer types are supported. These are
+/// converted to wider integers that are either 8 bits wide or wider.
+//
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -315,21 +323,28 @@ struct ConvertVectorMaskedStore final
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);
// Load the whole data and use arith.select to handle the corner cases.
- // E.g., given these input values:
+ // E.g., given these input i4 values:
+ //
+ // %res = vector.maskedload %0[%c0, %c0], %mask, %val_to_store :
+ //
+ // %mask = [1, 1, 1, 1, 1, 1, 1, 0] (8 * i1)
+ // %0[%c0, %c0] =
+ // [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8] (8 * i4)
+ // %val_to_store =
+ // [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0] (8 * i4)
//
- // %mask = [0, 1, 1, 1, 1, 1, 0, 0]
- // %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8]
- // %value_to_store = [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0]
+ // we'll have the following i4 output:
//
- // we'll have
+ // expected output: [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x8]
//
- // expected output: [0x1, 0xA, 0xB, 0xC, 0xD, 0xE, 0x7, 0x8]
+ // Emulating the above using i8 will give:
//
- // %new_mask = [1, 1, 1, 0]
- // %maskedload = [0x12, 0x34, 0x56, 0x00]
- // %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x0, 0x0]
- // %select_using_shifted_mask = [0x1, 0xA, 0xB, 0xC, 0xD, 0xE, 0x0, 0x0]
- // %packed_data = [0x1A, 0xBC, 0xDE, 0x00]
+ // %compressed_mask = [1, 1, 1, 1] (4 * i1)
+ // %maskedload = [0x12, 0x34, 0x56, 0x78] (4 * i8)
+ // %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8] (8 * i4)
+ // %select_using_shifted_mask =
+ // [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x8] (8 * i4)
+ // %packed_data = [0x9A, 0xBC, 0xDE, 0xF8] (4 * i8)
//
// Using the new mask to store %packed_data results in expected output.
FailureOr<Operation *> newMask =
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index cba299b2a1d956..c98b4dd50a7028 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -1,6 +1,10 @@
// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=32" --cse --split-input-file %s | FileCheck %s --check-prefix=CHECK32
+///----------------------------------------------------------------------------------------
+/// vector.load
+///----------------------------------------------------------------------------------------
+
func.func @vector_load_i8(%arg1: index, %arg2: index) -> vector<4xi8> {
%0 = memref.alloc() : memref<3x4xi8>
%1 = vector.load %0[%arg1, %arg2] : memref<3x4xi8>, vector<4xi8>
@@ -82,6 +86,10 @@ func.func @vector_load_i4_dynamic(%arg0 : index, %arg1 : index, %arg2 : index, %
// -----
+///----------------------------------------------------------------------------------------
+/// vector.transfer_read
+///----------------------------------------------------------------------------------------
+
func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> {
%c0 = arith.constant 0 : i4
%0 = memref.alloc() : memref<3x8xi4>
@@ -111,6 +119,10 @@ func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> {
// -----
+///----------------------------------------------------------------------------------------
+/// vector.maskedload
+///----------------------------------------------------------------------------------------
+
func.func @vector_maskedload_i8(%arg1: index, %arg2: index, %arg3: index, %passthru: vector<4xi8>) -> vector<4xi8> {
%0 = memref.alloc() : memref<3x4xi8>
%mask = vector.create_mask %arg3 : vector<4xi1>
@@ -263,6 +275,10 @@ func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vecto
// -----
+///----------------------------------------------------------------------------------------
+/// vector.extract -> vector.masked_load
+///----------------------------------------------------------------------------------------
+
func.func @vector_extract_maskedload_i4(%arg1: index) -> vector<8x8x16xi4> {
%0 = memref.alloc() : memref<8x8x16xi4>
%c0 = arith.constant 0 : index
@@ -353,6 +369,10 @@ func.func @vector_extract_cst_maskedload_i4() -> vector<8x8x16xi4> {
// -----
+///----------------------------------------------------------------------------------------
+/// vector.store
+///----------------------------------------------------------------------------------------
+
func.func @vector_store_i8(%arg0: vector<8xi8>, %arg1: index, %arg2: index) {
%0 = memref.alloc() : memref<4x8xi8>
vector.store %arg0, %0[%arg1, %arg2] :memref<4x8xi8>, vector<8xi8>
@@ -431,6 +451,10 @@ func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: ind
// -----
+///----------------------------------------------------------------------------------------
+/// vector.maskedstore
+///----------------------------------------------------------------------------------------
+
func.func @vector_maskedstore_i8(%arg0: index, %arg1: index, %arg2: index, %value: vector<8xi8>) {
%0 = memref.alloc() : memref<3x8xi8>
%mask = vector.create_mask %arg2 : vector<8xi1>
@@ -469,6 +493,61 @@ func.func @vector_maskedstore_i8(%arg0: index, %arg1: index, %arg2: index, %valu
// -----
+func.func @vector_maskedstore_i4(
+ %idx1: index,
+ %idx2: index,
+ %num_elements_to_store: index,
+ %value: vector<8xi4>) {
+
+ %0 = memref.alloc() : memref<3x8xi4>
+ %cst = arith.constant dense<0> : vector<3x8xi4>
+ %mask = vector.create_mask %num_elements_to_store : vector<8xi1>
+ vector.maskedstore %0[%idx1, %idx2], %mask, %value :
+ memref<3x8xi4>, vector<8xi1>, vector<8xi4>
+ return
+}
+// CHECK: #[[$ATTR_10:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+// CHECK: #[[$ATTR_11:.+]] = affine_map<()[s0] -> ((s0 + 1) floordiv 2)>
+
+// CHECK-LABEL: func.func @vector_maskedstore_i4(
+// CHECK-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[IDX_2:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[NUM_EL_TO_STORE:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+// CHECK: %[[ORIG_MASK:.+]] = vector.create_mask %[[NUM_EL_TO_STORE]] : vector<8xi1>
+// CHECK: %[[LIDX:.+]] = affine.apply #[[$ATTR_10]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
+// CHECK: %[[MASK_IDX:.+]] = affine.apply #[[$ATTR_11]](){{\[}}%[[NUM_EL_TO_STORE]]]
+// CHECK: %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<4xi1>
+// CHECK: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<4xi8>
+// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
+// CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
+// CHECK: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<4xi8>
+// CHECK: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<12xi8>, vector<4xi1>, vector<4xi8>
+
+// CHECK32: #[[$ATTR_17:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+// CHECK32: #[[$ATTR_18:.+]] = affine_map<()[s0] -> ((s0 + 7) floordiv 8)>
+
+// CHECK32-LABEL: func.func @vector_maskedstore_i4(
+// CHECK32-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
+// CHECK32-SAME: %[[IDX_2:[a-zA-Z0-9]+]]: index,
+// CHECK32-SAME: %[[NUM_EL_TO_STORE:[a-zA-Z0-9]+]]: index,
+// CHECK32-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+// CHECK32: %[[ORIG_MASK:.+]] = vector.create_mask %[[NUM_EL_TO_STORE]] : vector<8xi1>
+// CHECK32: %[[LIDX:.+]] = affine.apply #[[$ATTR_17]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
+// CHECK32: %[[MASK_IDX:.+]] = affine.apply #[[$ATTR_18]](){{\[}}%[[NUM_EL_TO_STORE]]]
+// CHECK32: %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<1xi1>
+// CHECK32: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<1xi32>
+// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
+// CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4>
+// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
+// CHECK32: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32>
+// CHECK32: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<3xi32>, vector<1xi1>, vector<1xi32>
+
+// -----
+
func.func @vector_cst_maskedstore_i8(%arg0: index, %arg1: index, %value: vector<8xi8>) {
%0 = memref.alloc() : memref<3x8xi8>
%mask = vector.constant_mask [4] : vector<8xi1>
@@ -500,3 +579,50 @@ func.func @vector_cst_maskedstore_i8(%arg0: index, %arg1: index, %value: vector<
// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL]], %[[BITCAST]] : vector<8xi1>, vector<8xi8>
// CHECK32: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi8> to vector<2xi32>
// CHECK32: vector.maskedstore %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]]
+
+// -----
+
+func.func @vector_cst_maskedstore_i4(
+ %idx_1: index,
+ %idx_2: index,
+ %val_to_store: vector<8xi4>) {
+
+ %0 = memref.alloc() : memref<3x8xi4>
+ %cst = arith.constant dense<0> : vector<3x8xi4>
+ %mask = vector.constant_mask [4] : vector<8xi1>
+ vector.maskedstore %0[%idx_1, %idx_2], %mask, %val_to_store :
+ memref<3x8xi4>, vector<8xi1>, vector<8xi4>
+ return
+}
+
+// CHECK: #[[$ATTR_12:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+// CHECK-LABEL: func.func @vector_cst_maskedstore_i4(
+// CHECK-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[IDX_2:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+// CHECK: %[[ORIG_MASK:.+]] = vector.constant_mask [4] : vector<8xi1>
+// CHECK: %[[LIDX:.+]] = affine.apply #[[$ATTR_12]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
+// CHECK: %[[NEW_MASK:.+]] = vector.constant_mask [2] : vector<4xi1>
+// CHECK: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<4xi8>
+// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+// CHECK: %[[VAL_9:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
+// CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[VAL_9]] : vector<8xi1>, vector<8xi4>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<4xi8>
+// CHECK: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[BITCAST]] : memref<12xi8>, vector<4xi1>, vector<4xi8>
+
+// CHECK32: #[[$ATTR_20:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+// CHECK32-LABEL: func.func @vector_cst_maskedstore_i4(
+// CHECK32-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
+// CHECK32-SAME: %[[IDX_2:[a-zA-Z0-9]+]]: index,
+// CHECK32-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+// CHECK32: %[[ORIG_MASK:.+]] = vector.constant_mask [4] : vector<8xi1>
+// CHECK32: %[[LIDX:.+]] = affine.apply #[[$ATTR_20]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
+// CHECK32: %[[NEW_MASK:.+]] = vector.constant_mask [1] : vector<1xi1>
+// CHECK32: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<1xi32>
+// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
+// CHECK32: %[[VAL_9:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4>
+// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_2]], %[[VAL_9]] : vector<8xi1>, vector<8xi4>
+// CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32>
+// CHECK32: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[BITCAST]] : memref<3xi32>, vector<1xi1>, vector<1xi32>
✅ With the latest revision this PR passed the C/C++ code formatter.
The documentation for narrow-type emulation is a bit inaccurate. In particular, we don't really support/generate masks like this:

%mask = [0, 1, 1, 1, 1, 1, 0, 0]

I updated the comment for ConvertVectorMaskedStore accordingly. I also added a few clarifications (e.g. that the comment is discussing i4 -> i8 emulation).

Separately, I've noticed an inconsistency in testing for narrow-type emulation. In particular, there are a few cases that are tested for "loading" and which are missing for "storing". I've added:
* comments in the test file so that it's easy to see what's tested,
* missing tests for vector.maskedstore.

Finally, I've added a top-level comment in VectorEmulateNarrowType.cpp so that the overall intent and design are clearer.
Force-pushed from 99deeab to 8a9abf6 (Compare)
…ulation

* Fix failing test
* Tweak/fix the comment
* Rename: @vector_cst_maskedload_i8 -> @vector_cst_maskedload_i8_constant_mask (same for other similar tests)
Hm, I was trying to come up with an example but am a bit stumped 😅. For instance:

%1 = vector.maskedload %0[%c2], %mask, %passthru : memref<10xi2>, vector<8xi1>, vector<8xi2> into vector<8xi2>

Even though i2 isn't likely byte-aligned, wouldn't the emulated load begin with the byte containing the first value to load? If so, for this initial mask for i2 (loading at %c2):

%mask_for_i2 = [1, 1, 1, 1, 1, 1, 1, 0]

we should generate the following mask for i8:

%mask_for_i8 = [1, 1]

In other words, I think we almost always start with [1, 1, ...] (at least one leading "1") and compress to [1, ...] (at least one leading "1"). Let me know if I'm missing something obvious 🤔

EDIT 9/11/24: I just discovered another point relevant to this discussion: 2D vectors (where I had been assuming only 1D vectors):

func.func @vector_maskedload_2d_i8_negative(
    %idx1: index,
    %idx2: index,
    %num_elems: index,
    %passthru: vector<2x4xi8>) -> vector<2x4xi8> {
  %0 = memref.alloc() : memref<3x4xi8>
  %mask = vector.create_mask %num_elems, %num_elems : vector<2x4xi1>
  %1 = vector.maskedload %0[%idx1, %idx2], %mask, %passthru :
    memref<3x4xi8>, vector<2x4xi1>, vector<2x4xi8> into vector<2x4xi8>
  return %1 : vector<2x4xi8>
}

TBH, it's not entirely clear what should happen here 😅. In fact, that's not supported at all:

within split at file.mlir:32 offset :11:10: error: 'vector.bitcast' op failed to verify that all of {source, result} have same rank
    %1 = vector.maskedload %0[%idx1, %idx2], %mask, %passthru :
         ^
within split at file.mlir:32 offset :11:10: note: see current operation: %11 = "vector.bitcast"(%arg3) : (vector<2x4xi8>) -> vector<2xi32>

Turns out there isn't a single test for the 2D case, and indeed, all of the relevant patterns fail. I've prepared a patch to document that. For simplicity's sake, I'm inclined to enforce 1D support only for now.
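As a concrete reference for the mask "compression" discussed above, here is a minimal sketch (illustrative SSA names, not taken from the patch) of what the i4 -> i8 path produces for a dynamic mask; it mirrors the affine maps checked in the new @vector_maskedstore_i4 test below:

// Original mask over the 8 i4 elements.
%mask = vector.create_mask %num_elems : vector<8xi1>
// Number of i8 containers spanned by the masked elements: ceil(num_elems / 2),
// i.e. (s0 + 1) floordiv 2. For i32 containers the map becomes (s0 + 7) floordiv 8.
%num_bytes = affine.apply affine_map<()[s0] -> ((s0 + 1) floordiv 2)>()[%num_elems]
// Compressed mask over the 4 i8 containers that back the 8 i4 elements.
%new_mask = vector.create_mask %num_bytes : vector<4xi1>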
// narrow types that are not supported by the target hardware, e.g. i4, using
// wider types, e.g. i8.
//
/// Currently, only power-of-two integer types are supported. These are
Note that there is actually nothing fundamental about power of 2 here. It is an implementation detail and just NYI. It might be better to make that clear.
+1, perhaps we can explicitly say that non-power-of-two types are not implemented yet in the description. That would make it clearer for newcomers to narrow-type emulation.
I will add a TODO at the top - clear indication for anyone brave enough to tackle this :)
@dcaballe @banach-space the preceding zeros will be used when a memref's innermost dim is not strictly byte-aligned. This could happen not only as input but also after some transformations such as tiling. So it is definitely needed (but best to avoid, as it sometimes emits poor code).
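A hedged illustration of the point above (not from the patch; %dst, %vals, %c0 and %c1 are assumed names) showing how such leading zeros arise when the innermost dim is not byte-aligned:

// memref<3x3xi4>: each row holds 3 i4 values (1.5 bytes), so row 1 starts in
// the middle of byte 1. Storing 3 i4 elements at [1, 0] touches bytes 1..2,
// whose i4 contents are [row0[2], row1[0], row1[1], row1[2]]. At i4
// granularity the store must therefore behave as if masked with [0, 1, 1, 1],
// so that row0[2] (which shares byte 1 with row1[0]) is preserved.
vector.store %vals, %dst[%c1, %c0] : memref<3x3xi4>, vector<3xi4>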
…type emulation

Refine the comment
Thank you for explaining; that makes sense! I'll restore the original logic in the comment. This only occurs in cases with sub-byte types, correct? It'd be helpful to call that out explicitly.

Also, considering the definitions of vector.constant_mask and vector.create_mask, how would we construct the corresponding mask? I assume it would be created like this:

func.func @vector_maskedload_i4_constant_mask(%passthru: vector<8xi4>) -> vector<8xi4> {
  %0 = memref.alloc() : memref<3x8xi4>
  %cst = arith.constant dense<0> : vector<8xi4>
  %mask = arith.constant dense<[false, true, true, true, true, false, false, false]> : vector<8xi1>
  %c0 = arith.constant 0 : index
  %1 = vector.maskedload %0[%c0, %c0], %mask, %passthru :
    memref<3x8xi4>, vector<8xi1>, vector<8xi4> into vector<8xi4>
  return %1 : vector<8xi4>
}

This example demonstrates the scenario described in the comment and could also help clarify the expected behavior. Unfortunately, it doesn't work as expected. This suggests that while the sub-byte alignment case is theoretically possible, it's likely that no one is actively relying on it, given that it's currently broken.

Since everyone is stretched thin, I'd propose proactively disabling and documenting these broken paths to prevent confusion or missteps down the road. Hence e.g.:

Just to clarify, I merely want to improve the overall health of this area. And, a bit selfishly, to make reviewing easier :) Like I said elsewhere, your contributions in this area @lialan are much appreciated!
👍🏻 As pointed out in my previous reply, I wasn't able to test that case. But I see what you and @lialan mean and I restored that part of the comment. Still, using an example that doesn't work to document behaviour is not ideal 😅 (perhaps there's a version that works?)
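For reference, a byte-aligned constant-mask variant does go through the emulation today; the sketch below is modelled on the existing @vector_cst_maskedload_i4 test from this file (the function name is hypothetical). Note that it only uses a leading-ones mask, so it does not exercise the leading-zero case discussed above:

// Hypothetical test name; mirrors @vector_cst_maskedload_i4 from the test file.
func.func @vector_maskedload_i4_aligned_constant_mask(%passthru: vector<8xi4>) -> vector<8xi4> {
  %0 = memref.alloc() : memref<3x8xi4>
  %c0 = arith.constant 0 : index
  // Leading-ones mask: the first 4 of the 8 i4 elements are active.
  %mask = vector.constant_mask [4] : vector<8xi1>
  %1 = vector.maskedload %0[%c0, %c0], %mask, %passthru :
    memref<3x8xi4>, vector<8xi1>, vector<8xi4> into vector<8xi4>
  return %1 : vector<8xi4>
}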
Thanks for improving the docs and tests, really appreciate it!
// CHECK-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
// CHECK: %[[ORIG_MASK:.+]] = vector.create_mask %[[NUM_EL_TO_STORE]] : vector<8xi1>
// CHECK: %[[LIDX:.+]] = affine.apply #[[$ATTR_10]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
nit: we don't need to escape [ because we do not capture % in variables. It is one of the pros of the trick, and it looks cleaner. Can you update it and fix the other checks in the new lit tests?
Suggested change:
- // CHECK: %[[LIDX:.+]] = affine.apply #[[$ATTR_10]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
+ // CHECK: %[[LIDX:.+]] = affine.apply #[[$ATTR_10]]()[%[[IDX_1]], %[[IDX_2]]]
Oh, very nice, thanks for pointing this out 🙏🏻
Just for future reference, the comment about the power-of-two restriction is in #113411 (comment). I marked the conversation unresolved, so people can see it easily in the PR.
…narrow-type emulation

Final tweaks
Thank you all for the discussion! I will land this later today if there are no new comments.
Thanks for digging this out, and apologies for not commenting when this was initially discussed. For now, I've created #115742 and referenced it in the comment. Hopefully, a volunteer can look into fixing it 😅; if not, we may need to revisit.
The documentation for narrow-type emulation was sparse, so I've expanded it with additional clarifications (e.g., specifying that the example discusses i4 -> i8 emulation).

I also noticed some inconsistencies in testing for narrow-type emulation, with several cases covered only for "loading" and missing for "storing." To address this, I've:
* added comments in the test file so that it's easy to see what's tested,
* added missing tests for vector.maskedstore.

Additionally, I've renamed tests for vector.masked{load|store} for clarity: @vector_cst_maskedload_i8 -> @vector_maskedload_i8_constant_mask. This makes it easier to contrast with similar functions, such as @vector_maskedload_i8.

Lastly, I've added a high-level comment in VectorEmulateNarrowType.cpp to clarify the overall design and intent of the file.