[mlir][vector] Restrict narrow-type-emulation patterns #115612

Merged
25 changes: 25 additions & 0 deletions mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -217,6 +217,11 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

// See #115653
if (op.getValueToStore().getType().getRank() != 1)
Contributor:

This might be a bug. Multi-dimensional vector.store should be supported, so the current lowering may simply be broken...

See comment below. It is explicitly written for multi-dimensional loads. The only general way to emulate sub-byte loads is to linearize the memrefs and do a linear store. So during the emulation the destination memref and the source vector get converted to 1D before the store.

I am not opposed to having this, but it seems too big a hammer. There is a bug here for multi-dimensional stores.
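
For context, a rough before/after sketch of what the linearizing emulation does in the supported 1-D case. This is illustrative only: the function and value names are made up, and the exact index arithmetic and wide element type depend on the pass configuration (e.g. memref-load-bitwidth).

```mlir
// Input (1-D, sub-byte elements):
func.func @store_i4_1d(%v: vector<8xi4>, %i: index) {
  %m = memref.alloc() : memref<16xi4>
  vector.store %v, %m[%i] : memref<16xi4>, vector<8xi4>
  return
}

// Rough shape of the emulated form (sketch, not the pass's exact output):
func.func @store_i4_1d_emulated(%v: vector<8xi4>, %i: index) {
  %m = memref.alloc() : memref<8xi8>        // linearized memref with wider elements
  %c2 = arith.constant 2 : index
  %scaled = arith.divui %i, %c2 : index     // 2 x i4 packed per i8
  %wide = vector.bitcast %v : vector<8xi4> to vector<4xi8>
  vector.store %wide, %m[%scaled] : memref<8xi8>, vector<4xi8>
  return
}
```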

Contributor (PR author):

> This might be a bug.

This is a bug :) In fact, one of many. Please see the summary ;-)

My PR is effectively a bug report. In fact, I should've started with a bug report. This is now reported here:

> See comment below. It is explicitly written for multi-dimensional loads. The only general way to emulate sub-byte loads is to linearize the memrefs and do a linear store.

Yes, two things need to happen: linearization + bitcasting. The former (linearization) seems to work fine, but only for the source/destination memref(s). For vectors, it appears to be broken. For reference, see the reproducers that I added as tests.
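
To make the missing vector-side step concrete, here is a hypothetical sketch of how a 2-D value could be linearized before the bitcast. This is not what the patterns currently do; the function and value names are made up, and only standard vector-dialect ops are used.

```mlir
func.func @flatten_then_bitcast(%v: vector<2x8xi4>) -> vector<8xi8> {
  // Linearize the 2-D value to 1-D first ...
  %flat = vector.shape_cast %v : vector<2x8xi4> to vector<16xi4>
  // ... then bitcast the 1-D value to the wider element type.
  %wide = vector.bitcast %flat : vector<16xi4> to vector<8xi8>
  return %wide : vector<8xi8>
}
```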

> I am not opposed to having this, but it seems too big a hammer. There is a bug here for multi-dimensional stores.

IIUC, we agree that there are multiple bugs here? This should be fixed, but in the meantime, let's document these "discoveries" through:

How does that sound?


As a side note ...

> So during the emulation the destination memref and the source vector get converted to 1D before the store.

From what I can tell, dealing with n-D vectors is going to be tricky and might take some time (especially when masking is involved). I'd start by making sure 3 basic cases are covered:

  • 1-D memref + 1-D vector,
  • 2-D memref + 1-D vector,
  • 2-D memref + 2-D vector.

The top two seem to be supported already; the bottom one is not (see the sketch below). I haven't thought about n-D cases (n > 2) yet, but perhaps those become trivial once 2-D is fully supported.
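
A sketch of that taxonomy, using vector.load as the representative op (function names and types are made up for illustration):

```mlir
// Case 1: 1-D memref + 1-D vector (supported).
func.func @case_1d_memref_1d_vec(%m: memref<16xi4>, %i: index) -> vector<8xi4> {
  %v = vector.load %m[%i] : memref<16xi4>, vector<8xi4>
  return %v : vector<8xi4>
}

// Case 2: 2-D memref + 1-D vector (supported).
func.func @case_2d_memref_1d_vec(%m: memref<4x16xi4>, %i: index, %j: index) -> vector<8xi4> {
  %v = vector.load %m[%i, %j] : memref<4x16xi4>, vector<8xi4>
  return %v : vector<8xi4>
}

// Case 3: 2-D memref + 2-D vector (not supported; these patterns now bail out).
func.func @case_2d_memref_2d_vec(%m: memref<4x16xi4>, %i: index, %j: index) -> vector<2x8xi4> {
  %v = vector.load %m[%i, %j] : memref<4x16xi4>, vector<2x8xi4>
  return %v : vector<2x8xi4>
}
```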

Contributor:

Your taxonomy is right. I think supporting multi-dimensional vectors is much more involved. With that context, looking back at your change, this makes total sense!

return rewriter.notifyMatchFailure(op,
"only 1-D vectors are supported ATM");

auto loc = op.getLoc();
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getValueToStore().getType().getElementType();
@@ -283,6 +288,11 @@ struct ConvertVectorMaskedStore final
matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

// See #115653
if (op.getValueToStore().getType().getRank() != 1)
return rewriter.notifyMatchFailure(op,
"only 1-D vectors are supported ATM");

auto loc = op.getLoc();
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getValueToStore().getType().getElementType();
@@ -372,6 +382,11 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

// See #115653
if (op.getVectorType().getRank() != 1)
return rewriter.notifyMatchFailure(op,
"only 1-D vectors are supported ATM");

auto loc = op.getLoc();
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getType().getElementType();
@@ -473,6 +488,11 @@ struct ConvertVectorMaskedLoad final
matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

// See #115653
if (op.getVectorType().getRank() != 1)
return rewriter.notifyMatchFailure(op,
"only 1-D vectors are supported ATM");

auto loc = op.getLoc();
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getType().getElementType();
@@ -624,6 +644,11 @@ struct ConvertVectorTransferRead final
matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

// See #115653
if (op.getVectorType().getRank() != 1)
return rewriter.notifyMatchFailure(op,
"only 1-D vectors are supported ATM");

auto loc = op.getLoc();
auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
Type oldElementType = op.getType().getElementType();
111 changes: 111 additions & 0 deletions mlir/test/Dialect/Vector/emulate-narrow-type-unsupported.mlir
@@ -0,0 +1,111 @@
// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=32 skip-memref-type-conversion" --split-input-file %s | FileCheck %s

// These tests mimic tests from vector-narrow-type.mlir, but load/store 2-D
// instead of 1-D vectors. That's currently not supported.

///----------------------------------------------------------------------------------------
/// vector.load
///----------------------------------------------------------------------------------------

func.func @vector_load_2d_i8_negative(%arg1: index, %arg2: index) -> vector<2x4xi8> {
%0 = memref.alloc() : memref<3x4xi8>
%1 = vector.load %0[%arg1, %arg2] : memref<3x4xi8>, vector<2x4xi8>
return %1 : vector<2x4xi8>
}

// No support for loading 2D vectors - expect no conversions
// CHECK-LABEL: func @vector_load_2d_i8_negative
// CHECK: memref.alloc() : memref<3x4xi8>
// CHECK-NOT: i32

// -----

///----------------------------------------------------------------------------------------
/// vector.transfer_read
///----------------------------------------------------------------------------------------

func.func @vector_transfer_read_2d_i4_negative(%arg1: index, %arg2: index) -> vector<2x8xi4> {
%c0 = arith.constant 0 : i4
%0 = memref.alloc() : memref<3x8xi4>
%1 = vector.transfer_read %0[%arg1, %arg2], %c0 {in_bounds = [true, true]} :
memref<3x8xi4>, vector<2x8xi4>
return %1 : vector<2x8xi4>
}
// CHECK-LABEL: func @vector_transfer_read_2d_i4_negative
// CHECK: memref.alloc() : memref<3x8xi4>
// CHECK-NOT: i32

// -----

///----------------------------------------------------------------------------------------
/// vector.maskedload
///----------------------------------------------------------------------------------------

func.func @vector_maskedload_2d_i8_negative(%arg1: index, %arg2: index, %arg3: index, %passthru: vector<2x4xi8>) -> vector<2x4xi8> {
%0 = memref.alloc() : memref<3x4xi8>
%mask = vector.create_mask %arg3, %arg3 : vector<2x4xi1>
%1 = vector.maskedload %0[%arg1, %arg2], %mask, %passthru :
memref<3x4xi8>, vector<2x4xi1>, vector<2x4xi8> into vector<2x4xi8>
return %1 : vector<2x4xi8>
}

// CHECK-LABEL: func @vector_maskedload_2d_i8_negative
// CHECK: memref.alloc() : memref<3x4xi8>
// CHECK-NOT: i32

// -----

///----------------------------------------------------------------------------------------
/// vector.extract -> vector.masked_load
///----------------------------------------------------------------------------------------

func.func @vector_extract_maskedload_2d_i4_negative(%arg1: index) -> vector<8x8x16xi4> {
%0 = memref.alloc() : memref<8x8x16xi4>
%c0 = arith.constant 0 : index
%c16 = arith.constant 16 : index
%c8 = arith.constant 8 : index
%cst_1 = arith.constant dense<0> : vector<8x8x16xi4>
%cst_2 = arith.constant dense<0> : vector<8x16xi4>
%27 = vector.create_mask %c8, %arg1, %c16 : vector<8x8x16xi1>
%48 = vector.extract %27[0] : vector<8x16xi1> from vector<8x8x16xi1>
%50 = vector.maskedload %0[%c0, %c0, %c0], %48, %cst_2 : memref<8x8x16xi4>, vector<8x16xi1>, vector<8x16xi4> into vector<8x16xi4>
%63 = vector.insert %50, %cst_1 [0] : vector<8x16xi4> into vector<8x8x16xi4>
return %63 : vector<8x8x16xi4>
}

// CHECK-LABEL: func @vector_extract_maskedload_2d_i4_negative
// CHECK: memref.alloc() : memref<8x8x16xi4>
// CHECK-NOT: i32

// -----

///----------------------------------------------------------------------------------------
/// vector.store
///----------------------------------------------------------------------------------------

func.func @vector_store_2d_i8_negative(%arg0: vector<2x8xi8>, %arg1: index, %arg2: index) {
%0 = memref.alloc() : memref<4x8xi8>
vector.store %arg0, %0[%arg1, %arg2] : memref<4x8xi8>, vector<2x8xi8>
return
}

// CHECK-LABEL: func @vector_store_2d_i8_negative
// CHECK: memref.alloc() : memref<4x8xi8>
// CHECK-NOT: i32

// -----

///----------------------------------------------------------------------------------------
/// vector.maskedstore
///----------------------------------------------------------------------------------------

func.func @vector_maskedstore_2d_i8_negative(%arg0: index, %arg1: index, %arg2: index, %value: vector<2x8xi8>) {
%0 = memref.alloc() : memref<3x8xi8>
%mask = vector.create_mask %arg2, %arg2 : vector<2x8xi1>
vector.maskedstore %0[%arg0, %arg1], %mask, %value : memref<3x8xi8>, vector<2x8xi1>, vector<2x8xi8>
return
}

// CHECK-LABEL: func @vector_maskedstore_2d_i8_negative
// CHECK: memref.alloc() : memref<3x8xi8>
// CHECK-NOT: i32
11 changes: 10 additions & 1 deletion mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
@@ -78,7 +78,11 @@ struct TestEmulateNarrowTypePass
IntegerType::get(ty.getContext(), arithComputeBitwidth));
});

memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);
// With the type converter enabled, we are effectively unable to write
// negative tests. This is a workaround specifically for negative tests.
if (!disableMemrefTypeConversion)
memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);

ConversionTarget target(*ctx);
target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
@@ -109,6 +113,11 @@ struct TestEmulateNarrowTypePass
Option<unsigned> arithComputeBitwidth{
*this, "arith-compute-bitwidth",
llvm::cl::desc("arith computation bit width"), llvm::cl::init(4)};

Option<bool> disableMemrefTypeConversion{
*this, "skip-memref-type-conversion",
llvm::cl::desc("disable memref type conversion (to test failures)"),
llvm::cl::init(false)};
};
} // namespace
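
For reference, the new skip-memref-type-conversion option is exercised from the RUN line of the added test; an equivalent invocation (copied from the test file above) looks like this:

```mlir
// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=32 skip-memref-type-conversion" --split-input-file %s | FileCheck %s
```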
