-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][memref] Fix type conversion in emulate-wide-int and emulate-narrow-type #112214
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-arith Author: Longsheng Mou (CoTinker) ChangesThis PR follows with #112104, using Full diff: https://github.com/llvm/llvm-project/pull/112214.diff 6 Files Affected:
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp
index 4be0e06fe2a5e5..fddd7c51bfbc87 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp
@@ -40,11 +40,11 @@ arith::NarrowTypeEmulationConverter::NarrowTypeEmulationConverter(
addConversion([this](FunctionType ty) -> std::optional<Type> {
SmallVector<Type> inputs;
if (failed(convertTypes(ty.getInputs(), inputs)))
- return std::nullopt;
+ return nullptr;
SmallVector<Type> results;
if (failed(convertTypes(ty.getResults(), results)))
- return std::nullopt;
+ return nullptr;
return FunctionType::get(ty.getContext(), inputs, results);
});
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 9efea066a03c85..1f461b9aa0abf0 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -169,7 +169,7 @@ struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
std::is_same<OpTy, memref::AllocaOp>(),
"expected only memref::AllocOp or memref::AllocaOp");
auto currentType = cast<MemRefType>(op.getMemref().getType());
- auto newResultType = dyn_cast<MemRefType>(
+ auto newResultType = dyn_cast_or_null<MemRefType>(
this->getTypeConverter()->convertType(op.getType()));
if (!newResultType) {
return rewriter.notifyMatchFailure(
@@ -377,8 +377,8 @@ struct ConvertMemRefReinterpretCast final
LogicalResult
matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- MemRefType newTy =
- dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
+ MemRefType newTy = dyn_cast_or_null<MemRefType>(
+ getTypeConverter()->convertType(op.getType()));
if (!newTy) {
return rewriter.notifyMatchFailure(
op->getLoc(),
@@ -466,7 +466,7 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
LogicalResult
matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- MemRefType newTy = dyn_cast<MemRefType>(
+ MemRefType newTy = dyn_cast_or_null<MemRefType>(
getTypeConverter()->convertType(subViewOp.getType()));
if (!newTy) {
return rewriter.notifyMatchFailure(
@@ -632,14 +632,14 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
SmallVector<int64_t> strides;
int64_t offset;
if (failed(getStridesAndOffset(ty, strides, offset)))
- return std::nullopt;
+ return nullptr;
if (!strides.empty() && strides.back() != 1)
- return std::nullopt;
+ return nullptr;
auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
intTy.getSignedness());
if (!newElemTy)
- return std::nullopt;
+ return nullptr;
StridedLayoutAttr layoutAttr;
// If the offset is 0, we do not need a strided layout as the stride is
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
index bc4535f97acf04..49b71625291db9 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
@@ -159,7 +159,7 @@ void memref::populateMemRefWideIntEmulationConversions(
Type newElemTy = typeConverter.convertType(intTy);
if (!newElemTy)
- return std::nullopt;
+ return nullptr;
return ty.cloneWith(std::nullopt, newElemTy);
});
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type-unsupported.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type-unsupported.mlir
new file mode 100644
index 00000000000000..024144337a31fb
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type-unsupported.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8" --verify-diagnostics --split-input-file %s
+
+func.func @negative_memref_subview_non_contiguous(%idx : index) -> i4 {
+ %c0 = arith.constant 0 : index
+ %arr = memref.alloc() : memref<40x40xi4>
+ // expected-error @+1 {{failed to legalize operation 'memref.subview' that was explicitly marked illegal}}
+ %subview = memref.subview %arr[%idx, 0] [4, 8] [1, 1] : memref<40x40xi4> to memref<4x8xi4, strided<[40, 1], offset:?>>
+ %ld = memref.load %subview[%c0, %c0] : memref<4x8xi4, strided<[40, 1], offset:?>>
+ return %ld : i4
+}
+
+// -----
+
+func.func @alloc_non_contiguous() {
+ // expected-error @+1 {{failed to legalize operation 'memref.alloc' that was explicitly marked illegal}}
+ %arr = memref.alloc() : memref<8x8xi4, strided<[1, 8]>>
+ return
+}
+
+// -----
+
+// expected-error @+1 {{failed to legalize operation 'func.func' that was explicitly marked illegal}}
+func.func @argument_non_contiguous(%arg0 : memref<8x8xi4, strided<[1, 8]>>) {
+ return
+}
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 540da239fced08..498f5d768e7358 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8" --cse --verify-diagnostics --split-input-file %s | FileCheck %s
-// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=32" --cse --verify-diagnostics --split-input-file %s | FileCheck %s --check-prefix=CHECK32
+// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=32" --cse --split-input-file %s | FileCheck %s --check-prefix=CHECK32
// Expect no conversions.
func.func @memref_i8() -> i8 {
@@ -203,18 +203,6 @@ func.func @memref_subview_dynamic_offset_i4(%idx : index) -> i4 {
// -----
-
-func.func @negative_memref_subview_non_contiguous(%idx : index) -> i4 {
- %c0 = arith.constant 0 : index
- %arr = memref.alloc() : memref<40x40xi4>
- // expected-error @+1 {{failed to legalize operation 'memref.subview' that was explicitly marked illegal}}
- %subview = memref.subview %arr[%idx, 0] [4, 8] [1, 1] : memref<40x40xi4> to memref<4x8xi4, strided<[40, 1], offset:?>>
- %ld = memref.load %subview[%c0, %c0] : memref<4x8xi4, strided<[40, 1], offset:?>>
- return %ld : i4
-}
-
-// -----
-
func.func @reinterpret_cast_memref_load_0D() -> i4 {
%0 = memref.alloc() : memref<5xi4>
%reinterpret_cast_0 = memref.reinterpret_cast %0 to offset: [0], sizes: [], strides: [] : memref<5xi4> to memref<i4>
@@ -540,16 +528,3 @@ func.func @memref_copy_i4(%arg0: memref<32x128xi4, 1>, %arg1: memref<32x128xi4>)
// CHECK32-SAME: %[[ARG0:.*]]: memref<512xi32, 1>, %[[ARG1:.*]]: memref<512xi32>
// CHECK32: memref.copy %[[ARG0]], %[[ARG1]]
// CHECK32: return
-
-// -----
-
-!colMajor = memref<8x8xi4, strided<[1, 8]>>
-func.func @copy_distinct_layouts(%idx : index) -> i4 {
- %c0 = arith.constant 0 : index
- %arr = memref.alloc() : memref<8x8xi4>
- %arr2 = memref.alloc() : !colMajor
- // expected-error @+1 {{failed to legalize operation 'memref.copy' that was explicitly marked illegal}}
- memref.copy %arr, %arr2 : memref<8x8xi4> to !colMajor
- %ld = memref.load %arr2[%c0, %c0] : !colMajor
- return %ld : i4
-}
diff --git a/mlir/test/Dialect/MemRef/emulate-wide-int-unsupported.mlir b/mlir/test/Dialect/MemRef/emulate-wide-int-unsupported.mlir
new file mode 100644
index 00000000000000..228e9a0bff7bcf
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/emulate-wide-int-unsupported.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt --memref-emulate-wide-int="widest-int-supported=32" \
+// RUN: --split-input-file --verify-diagnostics %s
+
+// Make sure we do not crash on unsupported types.
+
+func.func @alloc_i128() {
+ // expected-error@+1 {{failed to legalize operation 'memref.alloc' that was explicitly marked illegal}}
+ %m = memref.alloc() : memref<4xi128, 1>
+ return
+}
+
+// -----
+
+func.func @load_i128(%m: memref<4xi128, 1>) {
+ %c0 = arith.constant 0 : index
+ // expected-error@+1 {{failed to legalize operation 'memref.load' that was explicitly marked illegal}}
+ %v = memref.load %m[%c0] : memref<4xi128, 1>
+ return
+}
+
+// -----
+
+func.func @store_i128(%c1: i128, %m: memref<4xi128, 1>) {
+ %c0 = arith.constant 0 : index
+ // expected-error@+1 {{failed to legalize operation 'memref.store' that was explicitly marked illegal}}
+ memref.store %c1, %m[%c0] : memref<4xi128, 1>
+ return
+}
|
@llvm/pr-subscribers-mlir-memref Author: Longsheng Mou (CoTinker) ChangesThis PR follows with #112104, using Full diff: https://github.com/llvm/llvm-project/pull/112214.diff 6 Files Affected:
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp
index 4be0e06fe2a5e5..fddd7c51bfbc87 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp
@@ -40,11 +40,11 @@ arith::NarrowTypeEmulationConverter::NarrowTypeEmulationConverter(
addConversion([this](FunctionType ty) -> std::optional<Type> {
SmallVector<Type> inputs;
if (failed(convertTypes(ty.getInputs(), inputs)))
- return std::nullopt;
+ return nullptr;
SmallVector<Type> results;
if (failed(convertTypes(ty.getResults(), results)))
- return std::nullopt;
+ return nullptr;
return FunctionType::get(ty.getContext(), inputs, results);
});
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 9efea066a03c85..1f461b9aa0abf0 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -169,7 +169,7 @@ struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
std::is_same<OpTy, memref::AllocaOp>(),
"expected only memref::AllocOp or memref::AllocaOp");
auto currentType = cast<MemRefType>(op.getMemref().getType());
- auto newResultType = dyn_cast<MemRefType>(
+ auto newResultType = dyn_cast_or_null<MemRefType>(
this->getTypeConverter()->convertType(op.getType()));
if (!newResultType) {
return rewriter.notifyMatchFailure(
@@ -377,8 +377,8 @@ struct ConvertMemRefReinterpretCast final
LogicalResult
matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- MemRefType newTy =
- dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
+ MemRefType newTy = dyn_cast_or_null<MemRefType>(
+ getTypeConverter()->convertType(op.getType()));
if (!newTy) {
return rewriter.notifyMatchFailure(
op->getLoc(),
@@ -466,7 +466,7 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
LogicalResult
matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- MemRefType newTy = dyn_cast<MemRefType>(
+ MemRefType newTy = dyn_cast_or_null<MemRefType>(
getTypeConverter()->convertType(subViewOp.getType()));
if (!newTy) {
return rewriter.notifyMatchFailure(
@@ -632,14 +632,14 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
SmallVector<int64_t> strides;
int64_t offset;
if (failed(getStridesAndOffset(ty, strides, offset)))
- return std::nullopt;
+ return nullptr;
if (!strides.empty() && strides.back() != 1)
- return std::nullopt;
+ return nullptr;
auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
intTy.getSignedness());
if (!newElemTy)
- return std::nullopt;
+ return nullptr;
StridedLayoutAttr layoutAttr;
// If the offset is 0, we do not need a strided layout as the stride is
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
index bc4535f97acf04..49b71625291db9 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
@@ -159,7 +159,7 @@ void memref::populateMemRefWideIntEmulationConversions(
Type newElemTy = typeConverter.convertType(intTy);
if (!newElemTy)
- return std::nullopt;
+ return nullptr;
return ty.cloneWith(std::nullopt, newElemTy);
});
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type-unsupported.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type-unsupported.mlir
new file mode 100644
index 00000000000000..024144337a31fb
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type-unsupported.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8" --verify-diagnostics --split-input-file %s
+
+func.func @negative_memref_subview_non_contiguous(%idx : index) -> i4 {
+ %c0 = arith.constant 0 : index
+ %arr = memref.alloc() : memref<40x40xi4>
+ // expected-error @+1 {{failed to legalize operation 'memref.subview' that was explicitly marked illegal}}
+ %subview = memref.subview %arr[%idx, 0] [4, 8] [1, 1] : memref<40x40xi4> to memref<4x8xi4, strided<[40, 1], offset:?>>
+ %ld = memref.load %subview[%c0, %c0] : memref<4x8xi4, strided<[40, 1], offset:?>>
+ return %ld : i4
+}
+
+// -----
+
+func.func @alloc_non_contiguous() {
+ // expected-error @+1 {{failed to legalize operation 'memref.alloc' that was explicitly marked illegal}}
+ %arr = memref.alloc() : memref<8x8xi4, strided<[1, 8]>>
+ return
+}
+
+// -----
+
+// expected-error @+1 {{failed to legalize operation 'func.func' that was explicitly marked illegal}}
+func.func @argument_non_contiguous(%arg0 : memref<8x8xi4, strided<[1, 8]>>) {
+ return
+}
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 540da239fced08..498f5d768e7358 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8" --cse --verify-diagnostics --split-input-file %s | FileCheck %s
-// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=32" --cse --verify-diagnostics --split-input-file %s | FileCheck %s --check-prefix=CHECK32
+// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=32" --cse --split-input-file %s | FileCheck %s --check-prefix=CHECK32
// Expect no conversions.
func.func @memref_i8() -> i8 {
@@ -203,18 +203,6 @@ func.func @memref_subview_dynamic_offset_i4(%idx : index) -> i4 {
// -----
-
-func.func @negative_memref_subview_non_contiguous(%idx : index) -> i4 {
- %c0 = arith.constant 0 : index
- %arr = memref.alloc() : memref<40x40xi4>
- // expected-error @+1 {{failed to legalize operation 'memref.subview' that was explicitly marked illegal}}
- %subview = memref.subview %arr[%idx, 0] [4, 8] [1, 1] : memref<40x40xi4> to memref<4x8xi4, strided<[40, 1], offset:?>>
- %ld = memref.load %subview[%c0, %c0] : memref<4x8xi4, strided<[40, 1], offset:?>>
- return %ld : i4
-}
-
-// -----
-
func.func @reinterpret_cast_memref_load_0D() -> i4 {
%0 = memref.alloc() : memref<5xi4>
%reinterpret_cast_0 = memref.reinterpret_cast %0 to offset: [0], sizes: [], strides: [] : memref<5xi4> to memref<i4>
@@ -540,16 +528,3 @@ func.func @memref_copy_i4(%arg0: memref<32x128xi4, 1>, %arg1: memref<32x128xi4>)
// CHECK32-SAME: %[[ARG0:.*]]: memref<512xi32, 1>, %[[ARG1:.*]]: memref<512xi32>
// CHECK32: memref.copy %[[ARG0]], %[[ARG1]]
// CHECK32: return
-
-// -----
-
-!colMajor = memref<8x8xi4, strided<[1, 8]>>
-func.func @copy_distinct_layouts(%idx : index) -> i4 {
- %c0 = arith.constant 0 : index
- %arr = memref.alloc() : memref<8x8xi4>
- %arr2 = memref.alloc() : !colMajor
- // expected-error @+1 {{failed to legalize operation 'memref.copy' that was explicitly marked illegal}}
- memref.copy %arr, %arr2 : memref<8x8xi4> to !colMajor
- %ld = memref.load %arr2[%c0, %c0] : !colMajor
- return %ld : i4
-}
diff --git a/mlir/test/Dialect/MemRef/emulate-wide-int-unsupported.mlir b/mlir/test/Dialect/MemRef/emulate-wide-int-unsupported.mlir
new file mode 100644
index 00000000000000..228e9a0bff7bcf
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/emulate-wide-int-unsupported.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt --memref-emulate-wide-int="widest-int-supported=32" \
+// RUN: --split-input-file --verify-diagnostics %s
+
+// Make sure we do not crash on unsupported types.
+
+func.func @alloc_i128() {
+ // expected-error@+1 {{failed to legalize operation 'memref.alloc' that was explicitly marked illegal}}
+ %m = memref.alloc() : memref<4xi128, 1>
+ return
+}
+
+// -----
+
+func.func @load_i128(%m: memref<4xi128, 1>) {
+ %c0 = arith.constant 0 : index
+ // expected-error@+1 {{failed to legalize operation 'memref.load' that was explicitly marked illegal}}
+ %v = memref.load %m[%c0] : memref<4xi128, 1>
+ return
+}
+
+// -----
+
+func.func @store_i128(%c1: i128, %m: memref<4xi128, 1>) {
+ %c0 = arith.constant 0 : index
+ // expected-error@+1 {{failed to legalize operation 'memref.store' that was explicitly marked illegal}}
+ memref.store %c1, %m[%c0] : memref<4xi128, 1>
+ return
+}
|
1f5ebbd
to
90dfdc7
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for following up on this
7c1b27b
to
7337741
Compare
…rrow-type This PR follows with llvm#112104, using `nullptr` to indicate that type conversion failed and no fallback conversion should be attempted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks
This PR follows with #112104, using
nullptr
to indicate that type conversion failed and no fallback conversion should be attempted.