Skip to content

Commit 3db6152

Browse files
committed
[mlir][memref] Fix type conversion in emulate-wide-int and emulate-narrow-type
This PR follows with #112104, using `nullptr` to indicate that type conversion failed and no fallback conversion should be attempted.
1 parent 11f625c commit 3db6152

File tree

5 files changed

+51
-21
lines changed

5 files changed

+51
-21
lines changed

mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,11 @@ arith::NarrowTypeEmulationConverter::NarrowTypeEmulationConverter(
4040
addConversion([this](FunctionType ty) -> std::optional<Type> {
4141
SmallVector<Type> inputs;
4242
if (failed(convertTypes(ty.getInputs(), inputs)))
43-
return std::nullopt;
43+
return nullptr;
4444

4545
SmallVector<Type> results;
4646
if (failed(convertTypes(ty.getResults(), results)))
47-
return std::nullopt;
47+
return nullptr;
4848

4949
return FunctionType::get(ty.getContext(), inputs, results);
5050
});

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,9 @@ struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
169169
std::is_same<OpTy, memref::AllocaOp>(),
170170
"expected only memref::AllocOp or memref::AllocaOp");
171171
auto currentType = cast<MemRefType>(op.getMemref().getType());
172-
auto newResultType = dyn_cast<MemRefType>(
173-
this->getTypeConverter()->convertType(op.getType()));
172+
auto newResultType =
173+
this->getTypeConverter()->template convertType<MemRefType>(
174+
op.getType());
174175
if (!newResultType) {
175176
return rewriter.notifyMatchFailure(
176177
op->getLoc(),
@@ -378,7 +379,7 @@ struct ConvertMemRefReinterpretCast final
378379
matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
379380
ConversionPatternRewriter &rewriter) const override {
380381
MemRefType newTy =
381-
dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
382+
getTypeConverter()->convertType<MemRefType>(op.getType());
382383
if (!newTy) {
383384
return rewriter.notifyMatchFailure(
384385
op->getLoc(),
@@ -466,8 +467,8 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
466467
LogicalResult
467468
matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
468469
ConversionPatternRewriter &rewriter) const override {
469-
MemRefType newTy = dyn_cast<MemRefType>(
470-
getTypeConverter()->convertType(subViewOp.getType()));
470+
MemRefType newTy =
471+
getTypeConverter()->convertType<MemRefType>(subViewOp.getType());
471472
if (!newTy) {
472473
return rewriter.notifyMatchFailure(
473474
subViewOp->getLoc(),
@@ -632,14 +633,14 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
632633
SmallVector<int64_t> strides;
633634
int64_t offset;
634635
if (failed(getStridesAndOffset(ty, strides, offset)))
635-
return std::nullopt;
636+
return nullptr;
636637
if (!strides.empty() && strides.back() != 1)
637-
return std::nullopt;
638+
return nullptr;
638639

639640
auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
640641
intTy.getSignedness());
641642
if (!newElemTy)
642-
return std::nullopt;
643+
return nullptr;
643644

644645
StridedLayoutAttr layoutAttr;
645646
// If the offset is 0, we do not need a strided layout as the stride is

mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ void memref::populateMemRefWideIntEmulationConversions(
159159

160160
Type newElemTy = typeConverter.convertType(intTy);
161161
if (!newElemTy)
162-
return std::nullopt;
162+
return nullptr;
163163

164164
return ty.cloneWith(std::nullopt, newElemTy);
165165
});

mlir/test/Dialect/MemRef/emulate-narrow-type.mlir

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,6 @@ func.func @memref_subview_dynamic_offset_i4(%idx : index) -> i4 {
203203

204204
// -----
205205

206-
207206
func.func @negative_memref_subview_non_contiguous(%idx : index) -> i4 {
208207
%c0 = arith.constant 0 : index
209208
%arr = memref.alloc() : memref<40x40xi4>
@@ -543,13 +542,15 @@ func.func @memref_copy_i4(%arg0: memref<32x128xi4, 1>, %arg1: memref<32x128xi4>)
543542

544543
// -----
545544

546-
!colMajor = memref<8x8xi4, strided<[1, 8]>>
547-
func.func @copy_distinct_layouts(%idx : index) -> i4 {
548-
%c0 = arith.constant 0 : index
549-
%arr = memref.alloc() : memref<8x8xi4>
550-
%arr2 = memref.alloc() : !colMajor
551-
// expected-error @+1 {{failed to legalize operation 'memref.copy' that was explicitly marked illegal}}
552-
memref.copy %arr, %arr2 : memref<8x8xi4> to !colMajor
553-
%ld = memref.load %arr2[%c0, %c0] : !colMajor
554-
return %ld : i4
545+
func.func @alloc_non_contiguous() {
546+
// expected-error @+1 {{failed to legalize operation 'memref.alloc' that was explicitly marked illegal}}
547+
%arr = memref.alloc() : memref<8x8xi4, strided<[1, 8]>>
548+
return
549+
}
550+
551+
// -----
552+
553+
// expected-error @+1 {{failed to legalize operation 'func.func' that was explicitly marked illegal}}
554+
func.func @argument_non_contiguous(%arg0 : memref<8x8xi4, strided<[1, 8]>>) {
555+
return
555556
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// RUN: mlir-opt --memref-emulate-wide-int="widest-int-supported=32" \
2+
// RUN: --split-input-file --verify-diagnostics %s
3+
4+
// Make sure we do not crash on unsupported types.
5+
6+
func.func @alloc_i128() {
7+
// expected-error@+1 {{failed to legalize operation 'memref.alloc' that was explicitly marked illegal}}
8+
%m = memref.alloc() : memref<4xi128, 1>
9+
return
10+
}
11+
12+
// -----
13+
14+
func.func @load_i128(%m: memref<4xi128, 1>) {
15+
%c0 = arith.constant 0 : index
16+
// expected-error@+1 {{failed to legalize operation 'memref.load' that was explicitly marked illegal}}
17+
%v = memref.load %m[%c0] : memref<4xi128, 1>
18+
return
19+
}
20+
21+
// -----
22+
23+
func.func @store_i128(%c1: i128, %m: memref<4xi128, 1>) {
24+
%c0 = arith.constant 0 : index
25+
// expected-error@+1 {{failed to legalize operation 'memref.store' that was explicitly marked illegal}}
26+
memref.store %c1, %m[%c0] : memref<4xi128, 1>
27+
return
28+
}

0 commit comments

Comments
 (0)