[mlir][vector] Restrict narrow-type-emulation patterns #115612
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes
Patch is 27.28 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/115612.diff 4 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 58841f29698e0d..91da9bc9c7f8a0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1,11 +1,19 @@
-//===- VectorEmulateNarrowType.cpp - Narrow type emulation ----*- C++
-//-*-===//
+//===- VectorEmulateNarrowType.cpp - Narrow type emulation ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to emulate
+// narrow types that are not supported by the target hardware, e.g. i4, using
+// wider types, e.g. i8.
+//
+/// Currently, only power-of-two integer types are supported. These are
+/// converted to wider integers that are either 8 bits wide or wider.
+//
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -217,6 +225,10 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ if (op.getValueToStore().getType().getRank() != 1)
+ return rewriter.notifyMatchFailure(op,
+ "only 1-D vectors are supported ATM");
+
auto loc = op.getLoc();
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getValueToStore().getType().getElementType();
@@ -283,6 +295,10 @@ struct ConvertVectorMaskedStore final
matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ if (op.getValueToStore().getType().getRank() != 1)
+ return rewriter.notifyMatchFailure(op,
+ "only 1-D vectors are supported ATM");
+
auto loc = op.getLoc();
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getValueToStore().getType().getElementType();
@@ -315,23 +331,34 @@ struct ConvertVectorMaskedStore final
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);
// Load the whole data and use arith.select to handle the corner cases.
- // E.g., given these input values:
//
- // %mask = [0, 1, 1, 1, 1, 1, 0, 0]
- // %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8]
- // %value_to_store = [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0]
+ // As an example, for this masked store:
+ //
+ // vector.maskedstore %0[%c0, %c0], %mask, %val_to_store
//
- // we'll have
+ // and given these input i4 values:
//
- // expected output: [0x1, 0xA, 0xB, 0xC, 0xD, 0xE, 0x7, 0x8]
+ // %mask = [1, 1, 1, 1, 1, 0, 0, 0] (8 * i1)
+ // %0[%c0, %c0] =
+ // [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8] (8 * i4)
+ // %val_to_store =
+ // [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0] (8 * i4)
//
- // %new_mask = [1, 1, 1, 0]
- // %maskedload = [0x12, 0x34, 0x56, 0x00]
- // %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x0, 0x0]
- // %select_using_shifted_mask = [0x1, 0xA, 0xB, 0xC, 0xD, 0xE, 0x0, 0x0]
- // %packed_data = [0x1A, 0xBC, 0xDE, 0x00]
+ // we'll have the following i4 output:
//
- // Using the new mask to store %packed_data results in expected output.
+ // expected output: [0x9, 0xA, 0xB, 0xC, 0xD, 0x6, 0x7, 0x8]
+ //
+ // Emulating the above using i8 will give:
+ //
+ // %compressed_mask = [1, 1, 1, 0] (4 * i1)
+ // %maskedload = [0x12, 0x34, 0x56, 0x00] (4 * i8)
+ // %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x0, 0x0] (8 * i4)
+ // %select_using_shifted_mask =
+ // [0x9, 0xA, 0xB, 0xC, 0xD, 0x6, 0x0, 0x0] (8 * i4)
+ // %packed_data = [0x9A, 0xBC, 0xD6, 0x00] (4 * i8)
+ //
+ // Using the compressed mask to store %packed_data results in expected
+ // output.
FailureOr<Operation *> newMask =
getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
if (failed(newMask))
@@ -372,6 +399,10 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ if (op.getVectorType().getRank() != 1)
+ return rewriter.notifyMatchFailure(op,
+ "only 1-D vectors are supported ATM");
+
auto loc = op.getLoc();
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getType().getElementType();
@@ -473,6 +504,10 @@ struct ConvertVectorMaskedLoad final
matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ if (op.getVectorType().getRank() != 1)
+ return rewriter.notifyMatchFailure(op,
+ "only 1-D vectors are supported ATM");
+
auto loc = op.getLoc();
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getType().getElementType();
@@ -624,6 +659,10 @@ struct ConvertVectorTransferRead final
matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ if (op.getVectorType().getRank() != 1)
+ return rewriter.notifyMatchFailure(op,
+ "only 1-D vectors are supported ATM");
+
auto loc = op.getLoc();
auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
Type oldElementType = op.getType().getElementType();
diff --git a/mlir/test/Dialect/Vector/emulate-narrow-type-unsupported.mlir b/mlir/test/Dialect/Vector/emulate-narrow-type-unsupported.mlir
new file mode 100644
index 00000000000000..30ce13e8169c47
--- /dev/null
+++ b/mlir/test/Dialect/Vector/emulate-narrow-type-unsupported.mlir
@@ -0,0 +1,112 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=32 skip-memref-type-conversion" --split-input-file %s | FileCheck %s
+
+// These tests mimic tests from vector-narrow-type.mlir, but load/store 2-D
+// instead of 1-D vectors. That's currently not supported.
+
+///----------------------------------------------------------------------------------------
+/// vector.load
+///----------------------------------------------------------------------------------------
+
+func.func @vector_load_2d_i8_negative(%arg1: index, %arg2: index) -> vector<2x4xi8> {
+ %0 = memref.alloc() : memref<3x4xi8>
+ %1 = vector.load %0[%arg1, %arg2] : memref<3x4xi8>, vector<2x4xi8>
+ return %1 : vector<2x4xi8>
+}
+
+// No support for loading 2D vectors - expect no conversions
+// CHECK-LABEL: func @vector_load_2d_i8_negative
+// CHECK: memref.alloc() : memref<3x4xi8>
+// CHECK-NOT: i32
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// vector.transfer_read
+///----------------------------------------------------------------------------------------
+
+func.func @vector_transfer_read_2d_i4_negative(%arg1: index, %arg2: index) -> vector<2x8xi4> {
+ %c0 = arith.constant 0 : i4
+ %0 = memref.alloc() : memref<3x8xi4>
+ %1 = vector.transfer_read %0[%arg1, %arg2], %c0 {in_bounds = [true, true]} :
+ memref<3x8xi4>, vector<2x8xi4>
+ return %1 : vector<2x8xi4>
+}
+// CHECK-LABEL: func @vector_transfer_read_2d_i4_negative
+// CHECK: memref.alloc() : memref<3x8xi4>
+// CHECK-NOT: i32
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// vector.maskedload
+///----------------------------------------------------------------------------------------
+
+func.func @vector_maskedload_2d_i8_negative(%arg1: index, %arg2: index, %arg3: index, %passthru: vector<2x4xi8>) -> vector<2x4xi8> {
+ %0 = memref.alloc() : memref<3x4xi8>
+ %mask = vector.create_mask %arg3, %arg3 : vector<2x4xi1>
+ %1 = vector.maskedload %0[%arg1, %arg2], %mask, %passthru :
+ memref<3x4xi8>, vector<2x4xi1>, vector<2x4xi8> into vector<2x4xi8>
+ return %1 : vector<2x4xi8>
+}
+
+// CHECK-LABEL: func @vector_maskedload_2d_i8_negative
+// CHECK: memref.alloc() : memref<3x4xi8>
+// CHECK-NOT: i32
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// vector.extract -> vector.masked_load
+///----------------------------------------------------------------------------------------
+
+func.func @vector_extract_maskedload_2d_i4_negative(%arg1: index) -> vector<8x8x16xi4> {
+ %0 = memref.alloc() : memref<8x8x16xi4>
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c8 = arith.constant 8 : index
+ %cst_1 = arith.constant dense<0> : vector<8x8x16xi4>
+ %cst_2 = arith.constant dense<0> : vector<16xi4>
+ %27 = vector.create_mask %c8, %arg1, %c16 : vector<8x8x16xi1>
+ %48 = vector.extract %27[0] : vector<8x16xi1> from vector<8x8x16xi1>
+ %49 = vector.extract %48[0] : vector<16xi1> from vector<8x16xi1>
+ %50 = vector.maskedload %0[%c0, %c0, %c0], %49, %cst_2 : memref<8x8x16xi4>, vector<16xi1>, vector<16xi4> into vector<16xi4>
+ %63 = vector.insert %50, %cst_1 [0, 0] : vector<16xi4> into vector<8x8x16xi4>
+ return %63 : vector<8x8x16xi4>
+}
+
+// CHECK-LABEL: func @vector_extract_maskedload_2d_i4_negative
+// CHECK: memref.alloc() : memref<8x8x16xi4>
+// CHECK-NOT: i32
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// vector.store
+///----------------------------------------------------------------------------------------
+
+func.func @vector_store_2d_i8_negative(%arg0: vector<8xi8>, %arg1: index, %arg2: index) {
+ %0 = memref.alloc() : memref<4x8xi8>
+ vector.store %arg0, %0[%arg1, %arg2] :memref<4x8xi8>, vector<8xi8>
+ return
+}
+
+// CHECK-LABEL: func @vector_store_2d_i8_negative
+// CHECK: memref.alloc() : memref<4x8xi8>
+// CHECK-NOT: i32
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// vector.maskedstore
+///----------------------------------------------------------------------------------------
+
+func.func @vector_maskedstore_2d_i8_negative(%arg0: index, %arg1: index, %arg2: index, %value: vector<8xi8>) {
+ %0 = memref.alloc() : memref<3x8xi8>
+ %mask = vector.create_mask %arg2 : vector<8xi1>
+ vector.maskedstore %0[%arg0, %arg1], %mask, %value : memref<3x8xi8>, vector<8xi1>, vector<8xi8>
+ return
+}
+
+// CHECK-LABEL: func @vector_maskedstore_2d_i8_negative
+// CHECK: memref.alloc() : memref<3x8xi8>
+// CHECK-NOT: i32
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index cba299b2a1d956..5e139b04d7ee6f 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -1,6 +1,10 @@
// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=32" --cse --split-input-file %s | FileCheck %s --check-prefix=CHECK32
+///----------------------------------------------------------------------------------------
+/// vector.load
+///----------------------------------------------------------------------------------------
+
func.func @vector_load_i8(%arg1: index, %arg2: index) -> vector<4xi8> {
%0 = memref.alloc() : memref<3x4xi8>
%1 = vector.load %0[%arg1, %arg2] : memref<3x4xi8>, vector<4xi8>
@@ -82,6 +86,10 @@ func.func @vector_load_i4_dynamic(%arg0 : index, %arg1 : index, %arg2 : index, %
// -----
+///----------------------------------------------------------------------------------------
+/// vector.transfer_read
+///----------------------------------------------------------------------------------------
+
func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> {
%c0 = arith.constant 0 : i4
%0 = memref.alloc() : memref<3x8xi4>
@@ -111,6 +119,10 @@ func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> {
// -----
+///----------------------------------------------------------------------------------------
+/// vector.maskedload
+///----------------------------------------------------------------------------------------
+
func.func @vector_maskedload_i8(%arg1: index, %arg2: index, %arg3: index, %passthru: vector<4xi8>) -> vector<4xi8> {
%0 = memref.alloc() : memref<3x4xi8>
%mask = vector.create_mask %arg3 : vector<4xi1>
@@ -190,7 +202,7 @@ func.func @vector_maskedload_i4(%arg1: index, %arg2: index, %arg3: index, %passt
// -----
-func.func @vector_cst_maskedload_i8(%arg1: index, %arg2: index, %passthru: vector<4xi8>) -> vector<4xi8> {
+func.func @vector_maskedload_i8_constant_mask(%arg1: index, %arg2: index, %passthru: vector<4xi8>) -> vector<4xi8> {
%0 = memref.alloc() : memref<3x4xi8>
%mask = vector.constant_mask [2] : vector<4xi1>
%1 = vector.maskedload %0[%arg1, %arg2], %mask, %passthru :
@@ -198,7 +210,7 @@ func.func @vector_cst_maskedload_i8(%arg1: index, %arg2: index, %passthru: vecto
return %1 : vector<4xi8>
}
// Expect no conversions, i8 is supported.
-// CHECK: func @vector_cst_maskedload_i8(
+// CHECK: func @vector_maskedload_i8_constant_mask(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: vector<4xi8>)
// CHECK-NEXT: %[[ALLOC:.+]] = memref.alloc() : memref<3x4xi8>
@@ -208,7 +220,7 @@ func.func @vector_cst_maskedload_i8(%arg1: index, %arg2: index, %passthru: vecto
// CHECK-NEXT: return
// CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 4)>
-// CHECK32: func @vector_cst_maskedload_i8(
+// CHECK32: func @vector_maskedload_i8_constant_mask(
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK32-SAME: %[[ARG3:[a-zA-Z0-9]+]]: vector<4xi8>)
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
@@ -224,7 +236,7 @@ func.func @vector_cst_maskedload_i8(%arg1: index, %arg2: index, %passthru: vecto
// -----
-func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vector<8xi4>) -> vector<3x8xi4> {
+func.func @vector_maskedload_i4_constant_mask(%arg1: index, %arg2: index, %passthru: vector<8xi4>) -> vector<3x8xi4> {
%0 = memref.alloc() : memref<3x8xi4>
%cst = arith.constant dense<0> : vector<3x8xi4>
%mask = vector.constant_mask [4] : vector<8xi1>
@@ -234,7 +246,7 @@ func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vecto
return %2 : vector<3x8xi4>
}
// CHECK-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
-// CHECK: func @vector_cst_maskedload_i4(
+// CHECK: func @vector_maskedload_i4_constant_mask(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: vector<8xi4>)
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
@@ -248,7 +260,7 @@ func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vecto
// CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG2]] : vector<8xi1>, vector<8xi4>
// CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
-// CHECK32: func @vector_cst_maskedload_i4(
+// CHECK32: func @vector_maskedload_i4_constant_mask(
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: vector<8xi4>)
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
@@ -263,6 +275,10 @@ func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vecto
// -----
+///----------------------------------------------------------------------------------------
+/// vector.extract -> vector.masked_load
+///----------------------------------------------------------------------------------------
+
func.func @vector_extract_maskedload_i4(%arg1: index) -> vector<8x8x16xi4> {
%0 = memref.alloc() : memref<8x8x16xi4>
%c0 = arith.constant 0 : index
@@ -353,6 +369,10 @@ func.func @vector_extract_cst_maskedload_i4() -> vector<8x8x16xi4> {
// -----
+///----------------------------------------------------------------------------------------
+/// vector.store
+///----------------------------------------------------------------------------------------
+
func.func @vector_store_i8(%arg0: vector<8xi8>, %arg1: index, %arg2: index) {
%0 = memref.alloc() : memref<4x8xi8>
vector.store %arg0, %0[%arg1, %arg2] :memref<4x8xi8>, vector<8xi8>
@@ -431,6 +451,10 @@ func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: ind
// -----
+///----------------------------------------------------------------------------------------
+/// vector.maskedstore
+///----------------------------------------------------------------------------------------
+
func.func @vector_maskedstore_i8(%arg0: index, %arg1: index, %arg2: index, %value: vector<8xi8>) {
%0 = memref.alloc() : memref<3x8xi8>
%mask = vector.create_mask %arg2 : vector<8xi1>
@@ -469,14 +493,68 @@ func.func @vector_maskedstore_i8(%arg0: index, %arg1: index, %arg2: index, %valu
// -----
-func.func @vector_cst_maskedstore_i8(%arg0: index, %arg1: index, %value: vector<8xi8>) {
+func.func @vector_maskedstore_i4(
+ %idx1: index,
+ %idx2: index,
+ %num_elements_to_store: index,
+ %value: vector<8xi4>) {
+
+ %0 = memref.alloc() : memref<3x8xi4>
+ %mask = vector.create_mask %num_elements_to_store : vector<8xi1>
+ vector.maskedstore %0[%idx1, %idx2], %mask, %value :
+ memref<3x8xi4>, vector<8xi1>, vector<8xi4>
+ return
+}
+// CHECK: #[[$ATTR_10:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+// CHECK: #[[$ATTR_11:.+]] = affine_map<()[s0] -> ((s0 + 1) floordiv 2)>
+
+// CHECK-LABEL: func.func @vector_maskedstore_i4(
+// CHECK-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[IDX_2:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[NUM_EL_TO_STORE:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+// CHECK: %[[ORIG_MASK:.+]] = vector.create_mask %[[NUM_EL_TO_STORE]] : vector<8xi1>
+// CHECK: %[[LIDX:.+]] = affine.apply #[[$ATTR_10]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
+// CHECK: %[[MASK_IDX:.+]] = affine.apply #[[$ATTR_11]](){{\[}}%[[NUM_EL_TO_STORE]]]
+// CHECK: %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<4xi1>
+// CHECK: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<4xi8>
+// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
+// CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
+// CHECK: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<4xi8>
+// CHECK: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<12xi8>, vector<4xi1>, vector<4xi8>
+
+// CHECK32: #[[$ATTR_17:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+// CHECK32: #[[$ATTR_18:.+]] = affine_map<()[s0] -> ((s0 + 7) floordiv 8)>
+
+// CHECK32...
[truncated]
@@ -225,6 +225,10 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
  matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (op.getValueToStore().getType().getRank() != 1)
This might be a bug. Multi-dimensional vector.store should be supported, but there might be a bug...

See comment below. It is explicitly written for multi-dimensional loads. The only general way to emulate sub-byte loads is to linearize the memrefs and do a linear store. So during the emulation the destination memref and the source vector get converted to 1-D before the store.

I am not opposed to having this, but it seems too big a hammer. There is a bug here for multi-dimensional stores.
> This might be a bug.

This is a bug :) In fact, one of many. Please see the summary ;-)

My PR is effectively a bug report. In fact, I should've started with a bug report. This is now reported here: #115653.

> See comment below. It is explicitly written for multi-dimensional loads. The only general way to emulate sub-byte loads is to linearize the memrefs and do a linear store.

Yes, two things need to happen: linearization + bitcasting. The former (linearization) seems to work fine only for the source/destination memref(s); for vectors, it appears to be broken. For reference, see the reproducers that I added as tests.
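For context, here is a minimal sketch of what the linearization + bitcasting produces today on the supported 1-D path. It is loosely based on the `vector_maskedstore_i4` CHECK lines in this PR; the function name is illustrative and the exact emitted ops/maps may differ slightly.

```mlir
// 1-D i4 store into a 2-D i4 memref -- the case the patterns handle today.
func.func @store_i4_1d(%v: vector<8xi4>, %i: index, %j: index) {
  %m = memref.alloc() : memref<3x8xi4>
  vector.store %v, %m[%i, %j] : memref<3x8xi4>, vector<8xi4>
  return
}

// After emulation with memref-load-bitwidth=8, roughly:
//  * the memref is linearized and widened: memref<3x8xi4> -> memref<12xi8>
//  * the indices are collapsed via an affine map: (i, j) -> i * 4 + j floordiv 2
//  * the stored value is bitcast: vector<8xi4> -> vector<4xi8>
//  * a single 1-D store of the i8 vector is emitted
```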
> I am not opposed to having this, but seems too big a hammer. There is a bug here for multi-dimensional stores.

IIUC, we agree that there are multiple bugs here? This should be fixed, but in the meantime, let's document these "discoveries" through:

- a bug report: "Bugs in patterns under populateVectorNarrowTypeEmulationPatterns (1D vs 2D)" (#115653),
- reproducers added in this PR,
- code (explicit pattern failures for code paths that are known to be broken).

How does that sound?
As a side note ...

> So during the emulation the destination memref and the source vector get converted to 1D before the store.

From what I can tell, dealing with n-D vectors is going to be tricky and might take some time (especially when masking is involved). I'd start by making sure three basic cases are covered:

- 1-D memref + 1-D vector,
- 2-D memref + 1-D vector,
- 2-D memref + 2-D vector.

The top two seem to be supported already; the bottom one is not (see the sketch below). I haven't thought about n-D cases (n > 2) yet, but perhaps that's trivial once 2-D is fully supported.
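To make the second and third cases concrete, here is a minimal sketch; the function names are illustrative, and the 2-D-vector variant mirrors the negative tests added in emulate-narrow-type-unsupported.mlir.

```mlir
// 2-D memref + 1-D vector: emulated today.
func.func @load_1d_from_2d(%i: index, %j: index) -> vector<8xi4> {
  %m = memref.alloc() : memref<3x8xi4>
  %v = vector.load %m[%i, %j] : memref<3x8xi4>, vector<8xi4>
  return %v : vector<8xi4>
}

// 2-D memref + 2-D vector: the patterns now bail out with
// "only 1-D vectors are supported ATM".
func.func @load_2d_from_2d(%i: index, %j: index) -> vector<2x8xi4> {
  %m = memref.alloc() : memref<3x8xi4>
  %v = vector.load %m[%i, %j] : memref<3x8xi4>, vector<2x8xi4>
  return %v : vector<2x8xi4>
}
```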
Your taxonomy is right. I think supporting multi-dimensional vectors is much more involved. So, with that context, looking back at your change, this makes total sense!
All patterns under `populateVectorNarrowTypeEmulationPatterns` assume a 1-D vector load/store (as opposed to n-D vector load/store). This is evident from `ConvertVectorTransferRead`, e.g., here:

```cpp
auto newRead = rewriter.create<vector::TransferReadOp>(
    loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
    getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
    newPadding);

auto bitCast = rewriter.create<vector::BitCastOp>(
    loc, VectorType::get(numElements * scale, oldElementType), newRead);
```

Attempts to use these patterns in more generic cases fail, as shown below:

```mlir
func.func @vector_maskedload_2d_i8_negative(
    %idx1: index,
    %idx2: index,
    %num_elems: index,
    %passthru: vector<2x4xi8>) -> vector<2x4xi8> {

  %0 = memref.alloc() : memref<3x4xi8>
  %mask = vector.create_mask %num_elems, %num_elems : vector<2x4xi1>
  %1 = vector.maskedload %0[%idx1, %idx2], %mask, %passthru :
    memref<3x4xi8>, vector<2x4xi1>, vector<2x4xi8> into vector<2x4xi8>
  return %1 : vector<2x4xi8>
}
```

For example, casting to i32 produces:

```bash
error: 'vector.bitcast' op failed to verify that all of {source, result} have same rank
  %1 = vector.maskedload %0[%idx1, %idx2], %mask, %passthru :
       ^
```

Instead of reworking these patterns (that's going to require much more effort), I've marked them as 1-D only and extended "TestEmulateNarrowTypePass" with an option to disable the Memref type converter - that's to be able to add negative tests (otherwise, the type converter throws an error we can't really test for). While not ideal, this workaround should suit a test pass.
Force-pushed: "Update tests that were still loading/storing 1-D vectors" (f9b2652 to dfb8537).
Feels like it could be much easier to add a new pass to decompose higher-dimensional vector loads/stores into 1-D loads and stores before this pass. Existing code in
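Purely as an illustration of that suggestion (no such pass exists in this PR; the ops below are standard vector/arith ops, but the decomposition itself is hypothetical), a 2-D store could be unrolled into 1-D stores that the existing emulation already handles:

```mlir
// Before: a 2-D store that the emulation patterns cannot handle.
func.func @store_2d(%v: vector<2x8xi4>, %i: index, %j: index) {
  %m = memref.alloc() : memref<3x8xi4>
  vector.store %v, %m[%i, %j] : memref<3x8xi4>, vector<2x8xi4>
  return
}

// After a hypothetical decomposition pass: one 1-D store per row, which the
// existing narrow-type emulation can then rewrite.
func.func @store_2d_decomposed(%v: vector<2x8xi4>, %i: index, %j: index) {
  %m = memref.alloc() : memref<3x8xi4>
  %c1 = arith.constant 1 : index
  %i1 = arith.addi %i, %c1 : index
  %row0 = vector.extract %v[0] : vector<8xi4> from vector<2x8xi4>
  %row1 = vector.extract %v[1] : vector<8xi4> from vector<2x8xi4>
  vector.store %row0, %m[%i, %j] : memref<3x8xi4>, vector<8xi4>
  vector.store %row1, %m[%i1, %j] : memref<3x8xi4>, vector<8xi4>
  return
}
```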
Sure, but that won't change the fact that the current logic is buggy in the case that I highlighted. As in, even if such a pass were added, we'd still want to make sure that this code fails gracefully for cases that are not supported.
Ping @lialan |
All patterns in populateVectorNarrowTypeEmulationPatterns currently assume a 1-D vector load/store rather than an n-D vector load/store. This assumption is evident in ConvertVectorTransferRead: both invocations of VectorType::get() in the snippet quoted in the summary above generate a 1-D vector.

Attempts to use these patterns with more generic cases, such as 2-D vectors, fail. For example, trying to cast the 2-D vector.maskedload case above to i32 produces the 'vector.bitcast' rank-mismatch error shown earlier.

Instead of reworking these patterns (that's going to require much more effort), I've marked them as 1-D only and extended "TestEmulateNarrowTypePass" with an option to disable the Memref type converter - that's to be able to add negative tests (otherwise, the type converter throws an error we can't really test for). While not ideal, this workaround should suit a test pass.
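For reference, the negative tests added in emulate-narrow-type-unsupported.mlir exercise that new option through the following RUN line (copied from the new test file above):

```mlir
// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=32 skip-memref-type-conversion" --split-input-file %s | FileCheck %s
```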