[mlir][vector] Restrict narrow-type-emulation patterns #115612

Merged

Conversation

@banach-space (Contributor) commented Nov 9, 2024

All patterns in `populateVectorNarrowTypeEmulationPatterns` currently
assume a 1-D rather than an n-D vector load/store. This assumption is
evident in, for example, `ConvertVectorTransferRead` (extract below):

```cpp
auto newRead = rewriter.create<vector::TransferReadOp>(
    loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
    getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
    newPadding);

auto bitCast = rewriter.create<vector::BitCastOp>(
    loc, VectorType::get(numElements * scale, oldElementType), newRead);
```

Both invocations of `VectorType::get()` here generate a 1-D vector.
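For reference, this is roughly the shape of the 1-D IR these patterns emit. The sketch below is a hand-simplified version of the `vector.maskedstore` i4-to-i8 emulation exercised by the updated tests; the linearized index and compressed mask are passed in as arguments, and all names are illustrative:

```mlir
// A minimal sketch (not the pattern's verbatim output): emulating a masked
// store of vector<8xi4> with i8 as the container type. %lin_idx is the
// already-linearized index and %new_mask the already-compressed mask.
func.func @emulated_maskedstore_i4(%alloc: memref<12xi8>, %lin_idx: index,
                                   %orig_mask: vector<8xi1>,
                                   %new_mask: vector<4xi1>,
                                   %val_to_store: vector<8xi4>) {
  %pass_thru = arith.constant dense<0> : vector<4xi8>
  // Read the i8 bytes that back the i4 elements.
  %load = vector.maskedload %alloc[%lin_idx], %new_mask, %pass_thru
      : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
  // Reinterpret as i4, merge in the new values, and pack back into i8.
  %bitcast = vector.bitcast %load : vector<4xi8> to vector<8xi4>
  %select = arith.select %orig_mask, %val_to_store, %bitcast : vector<8xi1>, vector<8xi4>
  %new_val = vector.bitcast %select : vector<8xi4> to vector<4xi8>
  vector.maskedstore %alloc[%lin_idx], %new_mask, %new_val
      : memref<12xi8>, vector<4xi1>, vector<4xi8>
  return
}
```

Every vector in this sequence is 1-D; the bitcasts only change the element type and the number of elements, never the rank.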

Attempts to use these patterns in more generic cases, such as 2-D
vectors, fail. Take, for example, the following 2-D masked load:

```mlir
func.func @vector_maskedload_2d_i8_negative(
  %idx1: index,
  %idx2: index,
  %num_elems: index,
  %passthru: vector<2x4xi8>) -> vector<2x4xi8> {

    %0 = memref.alloc() : memref<3x4xi8>
    %mask = vector.create_mask %num_elems, %num_elems : vector<2x4xi1>
    %1 = vector.maskedload %0[%idx1, %idx2], %mask, %passthru :
      memref<3x4xi8>, vector<2x4xi1>, vector<2x4xi8> into vector<2x4xi8>
    return %1 : vector<2x4xi8>

}
```

Emulating it with i32 fails: the linearized, 1-D result of the emulated
load cannot be bitcast back to the original 2-D vector type, because
`vector.bitcast` requires the source and result ranks to match:

```
error: 'vector.bitcast' op failed to verify that all of {source, result} have same rank
    %1 = vector.maskedload %0[%idx1, %idx2], %mask, %passthru :
         ^
```

Instead of reworking these patterns (that would require much more
effort), I’ve marked them as 1-D only and extended
"TestEmulateNarrowTypePass" with an option to disable the MemRef type
converter. That makes it possible to add negative tests (otherwise the
type converter throws an error that we can't really test for). While not
ideal, this workaround should suit a test pass.

@llvmbot (Member) commented Nov 9, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes
  • [mlir][vector][nfc] Add tests + update docs for narrow-type emulation
  • fixup! [mlir][vector][nfc] Add tests + update docs for narrow-type emulation
  • [mlir][vector] Restrict narrow-type-emulation patterns

Patch is 27.28 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/115612.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+53-14)
  • (added) mlir/test/Dialect/Vector/emulate-narrow-type-unsupported.mlir (+112)
  • (modified) mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir (+133-9)
  • (modified) mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp (+10-1)
```diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 58841f29698e0d..91da9bc9c7f8a0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1,11 +1,19 @@
-//===- VectorEmulateNarrowType.cpp - Narrow type emulation ----*- C++
-//-*-===//
+//===- VectorEmulateNarrowType.cpp - Narrow type emulation ----------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to emulate
+// narrow types that are not supported by the target hardware, e.g. i4, using
+// wider types, e.g. i8.
+//
+/// Currently, only power-of-two integer types are supported. These are
+/// converted to wider integers that are either 8 bits wide or wider.
+//
+//===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -217,6 +225,10 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
+    if (op.getValueToStore().getType().getRank() != 1)
+      return rewriter.notifyMatchFailure(op,
+                                         "only 1-D vectors are supported ATM");
+
     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getValueToStore().getType().getElementType();
@@ -283,6 +295,10 @@ struct ConvertVectorMaskedStore final
   matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
+    if (op.getValueToStore().getType().getRank() != 1)
+      return rewriter.notifyMatchFailure(op,
+                                         "only 1-D vectors are supported ATM");
+
     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getValueToStore().getType().getElementType();
@@ -315,23 +331,34 @@ struct ConvertVectorMaskedStore final
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);
 
     // Load the whole data and use arith.select to handle the corner cases.
-    // E.g., given these input values:
     //
-    //   %mask = [0, 1, 1, 1, 1, 1, 0, 0]
-    //   %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8]
-    //   %value_to_store = [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0]
+    // As an example, for this masked store:
+    //
+    //   vector.maskedstore %0[%c0, %c0], %mask, %val_to_store
     //
-    // we'll have
+    // and given these input i4 values:
     //
-    //    expected output: [0x1, 0xA, 0xB, 0xC, 0xD, 0xE, 0x7, 0x8]
+    //   %mask = [1, 1, 1, 1, 1, 0, 0, 0]                     (8 * i1)
+    //   %0[%c0, %c0] =
+    //      [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8]          (8 * i4)
+    //   %val_to_store =
+    //      [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0]          (8 * i4)
     //
-    //    %new_mask = [1, 1, 1, 0]
-    //    %maskedload = [0x12, 0x34, 0x56, 0x00]
-    //    %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x0, 0x0]
-    //    %select_using_shifted_mask = [0x1, 0xA, 0xB, 0xC, 0xD, 0xE, 0x0, 0x0]
-    //    %packed_data = [0x1A, 0xBC, 0xDE, 0x00]
+    // we'll have the following i4 output:
     //
-    // Using the new mask to store %packed_data results in expected output.
+    //    expected output: [0x9, 0xA, 0xB, 0xC, 0xD, 0x6, 0x7, 0x8]
+    //
+    // Emulating the above using i8 will give:
+    //
+    //    %compressed_mask = [1, 1, 1, 0]                     (4 * i1)
+    //    %maskedload = [0x12, 0x34, 0x56, 0x00]              (4 * i8)
+    //    %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x0, 0x0] (8 * i4)
+    //    %select_using_shifted_mask =
+    //      [0x9, 0xA, 0xB, 0xC, 0xD, 0x6, 0x0, 0x0]          (8 * i4)
+    //    %packed_data = [0x9A, 0xBC, 0xD6, 0x00]             (4 * i8)
+    //
+    // Using the compressed mask to store %packed_data results in expected
+    // output.
     FailureOr<Operation *> newMask =
         getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
     if (failed(newMask))
@@ -372,6 +399,10 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
   matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
+    if (op.getVectorType().getRank() != 1)
+      return rewriter.notifyMatchFailure(op,
+                                         "only 1-D vectors are supported ATM");
+
     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getType().getElementType();
@@ -473,6 +504,10 @@ struct ConvertVectorMaskedLoad final
   matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
+    if (op.getVectorType().getRank() != 1)
+      return rewriter.notifyMatchFailure(op,
+                                         "only 1-D vectors are supported ATM");
+
     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getType().getElementType();
@@ -624,6 +659,10 @@ struct ConvertVectorTransferRead final
   matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
+    if (op.getVectorType().getRank() != 1)
+      return rewriter.notifyMatchFailure(op,
+                                         "only 1-D vectors are supported ATM");
+
     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
     Type oldElementType = op.getType().getElementType();
diff --git a/mlir/test/Dialect/Vector/emulate-narrow-type-unsupported.mlir b/mlir/test/Dialect/Vector/emulate-narrow-type-unsupported.mlir
new file mode 100644
index 00000000000000..30ce13e8169c47
--- /dev/null
+++ b/mlir/test/Dialect/Vector/emulate-narrow-type-unsupported.mlir
@@ -0,0 +1,112 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=32 skip-memref-type-conversion" --split-input-file %s | FileCheck %s
+
+// These tests mimic tests from vector-narrow-type.mlir, but load/store 2-D
+// instead of 1-D vectors. That's currently not supported.
+
+///----------------------------------------------------------------------------------------
+/// vector.load
+///----------------------------------------------------------------------------------------
+
+func.func @vector_load_2d_i8_negative(%arg1: index, %arg2: index) -> vector<2x4xi8> {
+    %0 = memref.alloc() : memref<3x4xi8>
+    %1 = vector.load %0[%arg1, %arg2] : memref<3x4xi8>, vector<2x4xi8>
+    return %1 : vector<2x4xi8>
+}
+
+// No support for loading 2D vectors - expect no conversions
+//  CHECK-LABEL: func @vector_load_2d_i8_negative
+//        CHECK:   memref.alloc() : memref<3x4xi8>
+//    CHECK-NOT: i32
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// vector.transfer_read
+///----------------------------------------------------------------------------------------
+
+func.func @vector_transfer_read_2d_i4_negative(%arg1: index, %arg2: index) -> vector<2x8xi4> {
+    %c0 = arith.constant 0 : i4
+    %0 = memref.alloc() : memref<3x8xi4>
+    %1 = vector.transfer_read %0[%arg1, %arg2], %c0 {in_bounds = [true, true]} :
+      memref<3x8xi4>, vector<2x8xi4>
+    return %1 : vector<2x8xi4>
+}
+//  CHECK-LABEL: func @vector_transfer_read_2d_i4_negative
+//        CHECK:  memref.alloc() : memref<3x8xi4>
+//    CHECK-NOT: i32
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// vector.maskedload
+///----------------------------------------------------------------------------------------
+
+func.func @vector_maskedload_2d_i8_negative(%arg1: index, %arg2: index, %arg3: index, %passthru: vector<2x4xi8>) -> vector<2x4xi8> {
+    %0 = memref.alloc() : memref<3x4xi8>
+    %mask = vector.create_mask %arg3, %arg3 : vector<2x4xi1>
+    %1 = vector.maskedload %0[%arg1, %arg2], %mask, %passthru :
+      memref<3x4xi8>, vector<2x4xi1>, vector<2x4xi8> into vector<2x4xi8>
+    return %1 : vector<2x4xi8>
+}
+
+//  CHECK-LABEL: func @vector_maskedload_2d_i8_negative
+//        CHECK:  memref.alloc() : memref<3x4xi8>
+//    CHECK-NOT: i32
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// vector.extract -> vector.masked_load
+///----------------------------------------------------------------------------------------
+
+func.func @vector_extract_maskedload_2d_i4_negative(%arg1: index) -> vector<8x8x16xi4> {
+    %0 = memref.alloc() : memref<8x8x16xi4>
+    %c0 = arith.constant 0 : index
+    %c16 = arith.constant 16 : index
+    %c8 = arith.constant 8 : index
+    %cst_1 = arith.constant dense<0> : vector<8x8x16xi4>
+    %cst_2 = arith.constant dense<0> : vector<16xi4>
+    %27 = vector.create_mask %c8, %arg1, %c16 : vector<8x8x16xi1>
+    %48 = vector.extract %27[0] : vector<8x16xi1> from vector<8x8x16xi1>
+    %49 = vector.extract %48[0] : vector<16xi1> from vector<8x16xi1>
+    %50 = vector.maskedload %0[%c0, %c0, %c0], %49, %cst_2 : memref<8x8x16xi4>, vector<16xi1>, vector<16xi4> into vector<16xi4>
+    %63 = vector.insert %50, %cst_1 [0, 0] : vector<16xi4> into vector<8x8x16xi4>
+    return %63 : vector<8x8x16xi4>
+}
+
+//  CHECK-LABEL: func @vector_extract_maskedload_2d_i4_negative
+//        CHECK: memref.alloc() : memref<8x8x16xi4>
+//    CHECK-NOT: i32
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// vector.store
+///----------------------------------------------------------------------------------------
+
+func.func @vector_store_2d_i8_negative(%arg0: vector<8xi8>, %arg1: index, %arg2: index) {
+    %0 = memref.alloc() : memref<4x8xi8>
+    vector.store %arg0, %0[%arg1, %arg2] :memref<4x8xi8>, vector<8xi8>
+    return
+}
+
+//  CHECK-LABEL: func @vector_store_2d_i8_negative
+//        CHECK: memref.alloc() : memref<4x8xi8>
+//    CHECK-NOT: i32
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// vector.maskedstore
+///----------------------------------------------------------------------------------------
+
+func.func @vector_maskedstore_2d_i8_negative(%arg0: index, %arg1: index, %arg2: index, %value: vector<8xi8>) {
+  %0 = memref.alloc() : memref<3x8xi8>
+  %mask = vector.create_mask %arg2 : vector<8xi1>
+  vector.maskedstore %0[%arg0, %arg1], %mask, %value : memref<3x8xi8>, vector<8xi1>, vector<8xi8>
+  return
+}
+
+//  CHECK-LABEL: func @vector_maskedstore_2d_i8_negative
+//        CHECK: memref.alloc() : memref<3x8xi8>
+//    CHECK-NOT: i32
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index cba299b2a1d956..5e139b04d7ee6f 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -1,6 +1,10 @@
 // RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
 // RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=32" --cse --split-input-file %s | FileCheck %s --check-prefix=CHECK32
 
+///----------------------------------------------------------------------------------------
+/// vector.load
+///----------------------------------------------------------------------------------------
+
 func.func @vector_load_i8(%arg1: index, %arg2: index) -> vector<4xi8> {
     %0 = memref.alloc() : memref<3x4xi8>
     %1 = vector.load %0[%arg1, %arg2] : memref<3x4xi8>, vector<4xi8>
@@ -82,6 +86,10 @@ func.func @vector_load_i4_dynamic(%arg0 : index, %arg1 : index, %arg2 : index, %
 
 // -----
 
+///----------------------------------------------------------------------------------------
+/// vector.transfer_read
+///----------------------------------------------------------------------------------------
+
 func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> {
     %c0 = arith.constant 0 : i4
     %0 = memref.alloc() : memref<3x8xi4>
@@ -111,6 +119,10 @@ func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> {
 
 // -----
 
+///----------------------------------------------------------------------------------------
+/// vector.maskedload
+///----------------------------------------------------------------------------------------
+
 func.func @vector_maskedload_i8(%arg1: index, %arg2: index, %arg3: index, %passthru: vector<4xi8>) -> vector<4xi8> {
     %0 = memref.alloc() : memref<3x4xi8>
     %mask = vector.create_mask %arg3 : vector<4xi1>
@@ -190,7 +202,7 @@ func.func @vector_maskedload_i4(%arg1: index, %arg2: index, %arg3: index, %passt
 
 // -----
 
-func.func @vector_cst_maskedload_i8(%arg1: index, %arg2: index, %passthru: vector<4xi8>) -> vector<4xi8> {
+func.func @vector_maskedload_i8_constant_mask(%arg1: index, %arg2: index, %passthru: vector<4xi8>) -> vector<4xi8> {
     %0 = memref.alloc() : memref<3x4xi8>
     %mask = vector.constant_mask [2] : vector<4xi1>
     %1 = vector.maskedload %0[%arg1, %arg2], %mask, %passthru :
@@ -198,7 +210,7 @@ func.func @vector_cst_maskedload_i8(%arg1: index, %arg2: index, %passthru: vecto
     return %1 : vector<4xi8>
 }
 // Expect no conversions, i8 is supported.
-//      CHECK: func @vector_cst_maskedload_i8(
+//      CHECK: func @vector_maskedload_i8_constant_mask(
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: vector<4xi8>)
 // CHECK-NEXT:   %[[ALLOC:.+]] = memref.alloc() : memref<3x4xi8>
@@ -208,7 +220,7 @@ func.func @vector_cst_maskedload_i8(%arg1: index, %arg2: index, %passthru: vecto
 // CHECK-NEXT:   return
 
 //  CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 4)>
-//      CHECK32: func @vector_cst_maskedload_i8(
+//      CHECK32: func @vector_maskedload_i8_constant_mask(
 // CHECK32-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
 // CHECK32-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: vector<4xi8>)
 //      CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
@@ -224,7 +236,7 @@ func.func @vector_cst_maskedload_i8(%arg1: index, %arg2: index, %passthru: vecto
 
 // -----
 
-func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vector<8xi4>) -> vector<3x8xi4> {
+func.func @vector_maskedload_i4_constant_mask(%arg1: index, %arg2: index, %passthru: vector<8xi4>) -> vector<3x8xi4> {
     %0 = memref.alloc() : memref<3x8xi4>
     %cst = arith.constant dense<0> : vector<3x8xi4>
     %mask = vector.constant_mask [4] : vector<8xi1>
@@ -234,7 +246,7 @@ func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vecto
     return %2 : vector<3x8xi4>
 }
 //  CHECK-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
-//      CHECK: func @vector_cst_maskedload_i4(
+//      CHECK: func @vector_maskedload_i4_constant_mask(
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: vector<8xi4>)
 //      CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
@@ -248,7 +260,7 @@ func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vecto
 //      CHECK:   %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG2]] : vector<8xi1>, vector<8xi4>
 
 //  CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
-//      CHECK32: func @vector_cst_maskedload_i4(
+//      CHECK32: func @vector_maskedload_i4_constant_mask(
 // CHECK32-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
 // CHECK32-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: vector<8xi4>)
 //      CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
@@ -263,6 +275,10 @@ func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vecto
 
 // -----
 
+///----------------------------------------------------------------------------------------
+/// vector.extract -> vector.masked_load
+///----------------------------------------------------------------------------------------
+
 func.func @vector_extract_maskedload_i4(%arg1: index) -> vector<8x8x16xi4> {
     %0 = memref.alloc() : memref<8x8x16xi4>
     %c0 = arith.constant 0 : index
@@ -353,6 +369,10 @@ func.func @vector_extract_cst_maskedload_i4() -> vector<8x8x16xi4> {
 
 // -----
 
+///----------------------------------------------------------------------------------------
+/// vector.store
+///----------------------------------------------------------------------------------------
+
 func.func @vector_store_i8(%arg0: vector<8xi8>, %arg1: index, %arg2: index) {
     %0 = memref.alloc() : memref<4x8xi8>
     vector.store %arg0, %0[%arg1, %arg2] :memref<4x8xi8>, vector<8xi8>
@@ -431,6 +451,10 @@ func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: ind
 
 // -----
 
+///----------------------------------------------------------------------------------------
+/// vector.maskedstore
+///----------------------------------------------------------------------------------------
+
 func.func @vector_maskedstore_i8(%arg0: index, %arg1: index, %arg2: index, %value: vector<8xi8>) {
   %0 = memref.alloc() : memref<3x8xi8>
   %mask = vector.create_mask %arg2 : vector<8xi1>
@@ -469,14 +493,68 @@ func.func @vector_maskedstore_i8(%arg0: index, %arg1: index, %arg2: index, %valu
 
 // -----
 
-func.func @vector_cst_maskedstore_i8(%arg0: index, %arg1: index, %value: vector<8xi8>) {
+func.func @vector_maskedstore_i4(
+  %idx1: index,
+  %idx2: index,
+  %num_elements_to_store: index,
+  %value: vector<8xi4>) {
+
+    %0 = memref.alloc() : memref<3x8xi4>
+    %mask = vector.create_mask %num_elements_to_store : vector<8xi1>
+    vector.maskedstore %0[%idx1, %idx2], %mask, %value :
+      memref<3x8xi4>, vector<8xi1>, vector<8xi4>
+    return
+}
+// CHECK: #[[$ATTR_10:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+// CHECK: #[[$ATTR_11:.+]] = affine_map<()[s0] -> ((s0 + 1) floordiv 2)>
+
+// CHECK-LABEL:   func.func @vector_maskedstore_i4(
+// CHECK-SAME:      %[[IDX_1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[IDX_2:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[NUM_EL_TO_STORE:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:      %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
+// CHECK:           %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+// CHECK:           %[[ORIG_MASK:.+]] = vector.create_mask %[[NUM_EL_TO_STORE]] : vector<8xi1>
+// CHECK:           %[[LIDX:.+]] = affine.apply #[[$ATTR_10]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
+// CHECK:           %[[MASK_IDX:.+]] = affine.apply #[[$ATTR_11]](){{\[}}%[[NUM_EL_TO_STORE]]]
+// CHECK:           %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<4xi1>
+// CHECK:           %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<4xi8>
+// CHECK:           %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+// CHECK:           %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
+// CHECK:           %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
+// CHECK:           %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<4xi8>
+// CHECK:           vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<12xi8>, vector<4xi1>, vector<4xi8>
+
+// CHECK32: #[[$ATTR_17:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+// CHECK32: #[[$ATTR_18:.+]] = affine_map<()[s0] -> ((s0 + 7) floordiv 8)>
+
// CHECK32...
```

[truncated]

@banach-space changed the title from "andrzej/emulate narrow type update 2" to "[mlir][vector] Restrict narrow-type-emulation patterns" on Nov 9, 2024
```diff
@@ -225,6 +225,10 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

if (op.getValueToStore().getType().getRank() != 1)
```
Contributor

This might be a bug. Multi-dimensional vector.store should be supported, but there might be a bug...

See comment below. It is explicitly written for multi-dimensional loads. The only general way to emulate sub-byte loads is to linearize the memrefs and do a linear store. So during the emulation the destination memref and the source vector get converted to 1D before the store.

I am not opposed to having this, but it seems too big a hammer. There is a bug here for multi-dimensional stores.

Contributor Author

> This might be a bug.

This is a bug :) In fact, one of many. Please see the summary ;-)

My PR is effectively a bug report; I should've started with one. This is now reported here:

> See comment below. It is explicitly written for multi-dimensional loads. The only general way to emulate sub-byte loads is to linearize the memrefs and do a linear store.

Yes, two things need to happen: linearization and bitcasting. The former (linearization) seems to work fine, but only for the source/destination memref(s); for vectors, it appears to be broken. For reference, see the reproducers that I added as tests.

> I am not opposed to having this, but it seems too big a hammer. There is a bug here for multi-dimensional stores.

IIUC, we agree that there are multiple bugs here? Those should be fixed, but in the meantime, let's document these "discoveries" through the 1-D restrictions and negative tests added in this PR.

How does that sound?


As a side note ...

> So during the emulation the destination memref and the source vector get converted to 1D before the store.

From what I can tell, dealing with n-D vectors is going to be tricky and might take some time (especially when masking is involved). I'd start by making sure three basic cases are covered:

  • 1-D memref + 1-D vector,
  • 2-D memref + 1-D vector,
  • 2-D memref + 2-D vector.

The first two already seem to be supported; the last one is not (see the sketch below). I haven't thought about n-D cases (n > 2) yet, but perhaps they become trivial once 2-D is fully supported.
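To make that taxonomy concrete, here is a minimal illustration using `vector.load` (the function and value names are just for this example):

```mlir
func.func @taxonomy(%m1: memref<12xi8>, %m2: memref<3x4xi8>, %i: index, %j: index) {
  // 1. 1-D memref + 1-D vector: supported.
  %a = vector.load %m1[%i] : memref<12xi8>, vector<4xi8>
  // 2. 2-D memref + 1-D vector: supported (only the memref gets linearized).
  %b = vector.load %m2[%i, %j] : memref<3x4xi8>, vector<4xi8>
  // 3. 2-D memref + 2-D vector: not supported (rejected by the patterns after this patch).
  %c = vector.load %m2[%i, %j] : memref<3x4xi8>, vector<2x4xi8>
  return
}
```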

Contributor

Your taxonomy is right. I think supporting multi-dimensional vectors is much more involved. With that context, looking back at your change, this makes total sense!

(Follow-up commit: "Update tests that were still loading/storing 1-D vectors")
@lialan (Member) commented Nov 11, 2024

It feels like it could be much easier to add a new pass that decomposes higher-dimensional vector loads/stores into 1-D loads and stores before this pass. The existing code in VectorEmulateNarrowType is already hard to swallow.
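For illustration, a rough sketch of the kind of decomposition being suggested: unroll the leading dimension of an n-D load into 1-D loads and reassemble the result. This is only an assumed shape for such a pass, not existing code, and the names are illustrative:

```mlir
func.func @decomposed_2d_load(%mem: memref<3x4xi8>, %i: index, %j: index) -> vector<2x4xi8> {
  // Original op: %v = vector.load %mem[%i, %j] : memref<3x4xi8>, vector<2x4xi8>
  %c1 = arith.constant 1 : index
  %i1 = arith.addi %i, %c1 : index
  %init = arith.constant dense<0> : vector<2x4xi8>
  // One 1-D load per row of the 2-D result ...
  %row0 = vector.load %mem[%i, %j] : memref<3x4xi8>, vector<4xi8>
  %row1 = vector.load %mem[%i1, %j] : memref<3x4xi8>, vector<4xi8>
  // ... reassembled with vector.insert.
  %tmp = vector.insert %row0, %init [0] : vector<4xi8> into vector<2x4xi8>
  %res = vector.insert %row1, %tmp [1] : vector<4xi8> into vector<2x4xi8>
  return %res : vector<2x4xi8>
}
```

With the loads/stores already 1-D, the existing narrow-type-emulation patterns would then apply unchanged.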

@banach-space (Contributor Author) replied:

> It feels like it could be much easier to add a new pass that decomposes higher-dimensional vector loads/stores into 1-D loads and stores before this pass. The existing code in VectorEmulateNarrowType is already hard to swallow.

Sure, but that won't change the fact that the current logic is buggy in the case that I highlighted:

  • 2-D memref + 2-D vector.

Put differently, even if such a pass were added, we'd still want to make sure that this code fails gracefully for cases that are not supported.

@banach-space (Contributor Author)

Ping @lialan

@banach-space merged commit e458434 into llvm:main on Nov 12, 2024
8 checks passed
@banach-space deleted the andrzej/emulate_narrow_type_update_2 branch on January 13, 2025