Skip to content

Commit f5aee1f

Browse files
authored
[mlir][memref] Fix type conversion in emulate-wide-int and emulate-narrow-type (#112214)
This PR follows with #112104, using `nullptr` to indicate that type conversion failed and no fallback conversion should be attempted.
1 parent 90767bc commit f5aee1f

File tree

5 files changed

+57
-22
lines changed

5 files changed

+57
-22
lines changed

mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,11 @@ arith::NarrowTypeEmulationConverter::NarrowTypeEmulationConverter(
4040
addConversion([this](FunctionType ty) -> std::optional<Type> {
4141
SmallVector<Type> inputs;
4242
if (failed(convertTypes(ty.getInputs(), inputs)))
43-
return std::nullopt;
43+
return nullptr;
4444

4545
SmallVector<Type> results;
4646
if (failed(convertTypes(ty.getResults(), results)))
47-
return std::nullopt;
47+
return nullptr;
4848

4949
return FunctionType::get(ty.getContext(), inputs, results);
5050
});

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,9 @@ struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
169169
std::is_same<OpTy, memref::AllocaOp>(),
170170
"expected only memref::AllocOp or memref::AllocaOp");
171171
auto currentType = cast<MemRefType>(op.getMemref().getType());
172-
auto newResultType = dyn_cast<MemRefType>(
173-
this->getTypeConverter()->convertType(op.getType()));
172+
auto newResultType =
173+
this->getTypeConverter()->template convertType<MemRefType>(
174+
op.getType());
174175
if (!newResultType) {
175176
return rewriter.notifyMatchFailure(
176177
op->getLoc(),
@@ -378,7 +379,7 @@ struct ConvertMemRefReinterpretCast final
378379
matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
379380
ConversionPatternRewriter &rewriter) const override {
380381
MemRefType newTy =
381-
dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
382+
getTypeConverter()->convertType<MemRefType>(op.getType());
382383
if (!newTy) {
383384
return rewriter.notifyMatchFailure(
384385
op->getLoc(),
@@ -466,8 +467,8 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
466467
LogicalResult
467468
matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
468469
ConversionPatternRewriter &rewriter) const override {
469-
MemRefType newTy = dyn_cast<MemRefType>(
470-
getTypeConverter()->convertType(subViewOp.getType()));
470+
MemRefType newTy =
471+
getTypeConverter()->convertType<MemRefType>(subViewOp.getType());
471472
if (!newTy) {
472473
return rewriter.notifyMatchFailure(
473474
subViewOp->getLoc(),
@@ -632,14 +633,14 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
632633
SmallVector<int64_t> strides;
633634
int64_t offset;
634635
if (failed(getStridesAndOffset(ty, strides, offset)))
635-
return std::nullopt;
636+
return nullptr;
636637
if (!strides.empty() && strides.back() != 1)
637-
return std::nullopt;
638+
return nullptr;
638639

639640
auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
640641
intTy.getSignedness());
641642
if (!newElemTy)
642-
return std::nullopt;
643+
return nullptr;
643644

644645
StridedLayoutAttr layoutAttr;
645646
// If the offset is 0, we do not need a strided layout as the stride is

mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ void memref::populateMemRefWideIntEmulationConversions(
159159

160160
Type newElemTy = typeConverter.convertType(intTy);
161161
if (!newElemTy)
162-
return std::nullopt;
162+
return nullptr;
163163

164164
return ty.cloneWith(std::nullopt, newElemTy);
165165
});

mlir/test/Dialect/MemRef/emulate-narrow-type.mlir

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,6 @@ func.func @memref_subview_dynamic_offset_i4(%idx : index) -> i4 {
203203

204204
// -----
205205

206-
207206
func.func @negative_memref_subview_non_contiguous(%idx : index) -> i4 {
208207
%c0 = arith.constant 0 : index
209208
%arr = memref.alloc() : memref<40x40xi4>
@@ -543,13 +542,15 @@ func.func @memref_copy_i4(%arg0: memref<32x128xi4, 1>, %arg1: memref<32x128xi4>)
543542

544543
// -----
545544

546-
!colMajor = memref<8x8xi4, strided<[1, 8]>>
547-
func.func @copy_distinct_layouts(%idx : index) -> i4 {
548-
%c0 = arith.constant 0 : index
549-
%arr = memref.alloc() : memref<8x8xi4>
550-
%arr2 = memref.alloc() : !colMajor
551-
// expected-error @+1 {{failed to legalize operation 'memref.copy' that was explicitly marked illegal}}
552-
memref.copy %arr, %arr2 : memref<8x8xi4> to !colMajor
553-
%ld = memref.load %arr2[%c0, %c0] : !colMajor
554-
return %ld : i4
545+
func.func @alloc_non_contiguous() {
546+
// expected-error @+1 {{failed to legalize operation 'memref.alloc' that was explicitly marked illegal}}
547+
%arr = memref.alloc() : memref<8x8xi4, strided<[1, 8]>>
548+
return
549+
}
550+
551+
// -----
552+
553+
// expected-error @+1 {{failed to legalize operation 'func.func' that was explicitly marked illegal}}
554+
func.func @argument_non_contiguous(%arg0 : memref<8x8xi4, strided<[1, 8]>>) {
555+
return
555556
}

mlir/test/Dialect/MemRef/emulate-wide-int.mlir

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
// RUN: mlir-opt --memref-emulate-wide-int="widest-int-supported=32" %s | FileCheck %s
1+
// RUN: mlir-opt --memref-emulate-wide-int="widest-int-supported=32" %s \
2+
// RUN: --split-input-file --verify-diagnostics | FileCheck %s
23

34
// Expect no conversions, i32 is supported.
45
// CHECK-LABEL: func @memref_i32
@@ -15,6 +16,8 @@ func.func @memref_i32() {
1516
return
1617
}
1718

19+
// -----
20+
1821
// Expect no conversions, f64 is not an integer type.
1922
// CHECK-LABEL: func @memref_f32
2023
// CHECK: [[M:%.+]] = memref.alloc() : memref<4xf32, 1>
@@ -30,6 +33,8 @@ func.func @memref_f32() {
3033
return
3134
}
3235

36+
// -----
37+
3338
// CHECK-LABEL: func @alloc_load_store_i64
3439
// CHECK: [[C1:%.+]] = arith.constant dense<[1, 0]> : vector<2xi32>
3540
// CHECK-NEXT: [[M:%.+]] = memref.alloc() : memref<4xvector<2xi32>, 1>
@@ -45,6 +50,7 @@ func.func @alloc_load_store_i64() {
4550
return
4651
}
4752

53+
// -----
4854

4955
// CHECK-LABEL: func @alloc_load_store_i64_nontemporal
5056
// CHECK: [[C1:%.+]] = arith.constant dense<[1, 0]> : vector<2xi32>
@@ -60,3 +66,30 @@ func.func @alloc_load_store_i64_nontemporal() {
6066
memref.store %c1, %m[%c0] {nontemporal = true} : memref<4xi64, 1>
6167
return
6268
}
69+
70+
// -----
71+
72+
// Make sure we do not crash on unsupported types.
73+
func.func @alloc_i128() {
74+
// expected-error@+1 {{failed to legalize operation 'memref.alloc' that was explicitly marked illegal}}
75+
%m = memref.alloc() : memref<4xi128, 1>
76+
return
77+
}
78+
79+
// -----
80+
81+
func.func @load_i128(%m: memref<4xi128, 1>) {
82+
%c0 = arith.constant 0 : index
83+
// expected-error@+1 {{failed to legalize operation 'memref.load' that was explicitly marked illegal}}
84+
%v = memref.load %m[%c0] : memref<4xi128, 1>
85+
return
86+
}
87+
88+
// -----
89+
90+
func.func @store_i128(%c1: i128, %m: memref<4xi128, 1>) {
91+
%c0 = arith.constant 0 : index
92+
// expected-error@+1 {{failed to legalize operation 'memref.store' that was explicitly marked illegal}}
93+
memref.store %c1, %m[%c0] : memref<4xi128, 1>
94+
return
95+
}

0 commit comments

Comments
 (0)