Skip to content

Commit 90dfdc7

Browse files
committed
[mlir][memref] Fix type conversion in emulate-wide-int and emulate-narrow-type
This PR follows with llvm#112104, using `nullptr` to indicate that type conversion failed and no fallback conversion should be attempted.
1 parent 11f625c commit 90dfdc7

File tree

6 files changed

+66
-38
lines changed

6 files changed

+66
-38
lines changed

mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,11 @@ arith::NarrowTypeEmulationConverter::NarrowTypeEmulationConverter(
4040
addConversion([this](FunctionType ty) -> std::optional<Type> {
4141
SmallVector<Type> inputs;
4242
if (failed(convertTypes(ty.getInputs(), inputs)))
43-
return std::nullopt;
43+
return nullptr;
4444

4545
SmallVector<Type> results;
4646
if (failed(convertTypes(ty.getResults(), results)))
47-
return std::nullopt;
47+
return nullptr;
4848

4949
return FunctionType::get(ty.getContext(), inputs, results);
5050
});

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,8 @@ struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
169169
std::is_same<OpTy, memref::AllocaOp>(),
170170
"expected only memref::AllocOp or memref::AllocaOp");
171171
auto currentType = cast<MemRefType>(op.getMemref().getType());
172-
auto newResultType = dyn_cast<MemRefType>(
173-
this->getTypeConverter()->convertType(op.getType()));
172+
auto newResultType =
173+
this->getTypeConverter()->convertType<MemRefType>(op.getType());
174174
if (!newResultType) {
175175
return rewriter.notifyMatchFailure(
176176
op->getLoc(),
@@ -378,7 +378,7 @@ struct ConvertMemRefReinterpretCast final
378378
matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
379379
ConversionPatternRewriter &rewriter) const override {
380380
MemRefType newTy =
381-
dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
381+
getTypeConverter()->convertType<MemRefType>(op.getType());
382382
if (!newTy) {
383383
return rewriter.notifyMatchFailure(
384384
op->getLoc(),
@@ -466,8 +466,8 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
466466
LogicalResult
467467
matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
468468
ConversionPatternRewriter &rewriter) const override {
469-
MemRefType newTy = dyn_cast<MemRefType>(
470-
getTypeConverter()->convertType(subViewOp.getType()));
469+
MemRefType newTy =
470+
getTypeConverter()->convertType<MemRefType>(subViewOp.getType());
471471
if (!newTy) {
472472
return rewriter.notifyMatchFailure(
473473
subViewOp->getLoc(),
@@ -632,14 +632,14 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
632632
SmallVector<int64_t> strides;
633633
int64_t offset;
634634
if (failed(getStridesAndOffset(ty, strides, offset)))
635-
return std::nullopt;
635+
return nullptr;
636636
if (!strides.empty() && strides.back() != 1)
637-
return std::nullopt;
637+
return nullptr;
638638

639639
auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
640640
intTy.getSignedness());
641641
if (!newElemTy)
642-
return std::nullopt;
642+
return nullptr;
643643

644644
StridedLayoutAttr layoutAttr;
645645
// If the offset is 0, we do not need a strided layout as the stride is

mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ void memref::populateMemRefWideIntEmulationConversions(
159159

160160
Type newElemTy = typeConverter.convertType(intTy);
161161
if (!newElemTy)
162-
return std::nullopt;
162+
return nullptr;
163163

164164
return ty.cloneWith(std::nullopt, newElemTy);
165165
});
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8" --verify-diagnostics --split-input-file %s
2+
3+
func.func @negative_memref_subview_non_contiguous(%idx : index) -> i4 {
4+
%c0 = arith.constant 0 : index
5+
%arr = memref.alloc() : memref<40x40xi4>
6+
// expected-error @+1 {{failed to legalize operation 'memref.subview' that was explicitly marked illegal}}
7+
%subview = memref.subview %arr[%idx, 0] [4, 8] [1, 1] : memref<40x40xi4> to memref<4x8xi4, strided<[40, 1], offset:?>>
8+
%ld = memref.load %subview[%c0, %c0] : memref<4x8xi4, strided<[40, 1], offset:?>>
9+
return %ld : i4
10+
}
11+
12+
// -----
13+
14+
func.func @alloc_non_contiguous() {
15+
// expected-error @+1 {{failed to legalize operation 'memref.alloc' that was explicitly marked illegal}}
16+
%arr = memref.alloc() : memref<8x8xi4, strided<[1, 8]>>
17+
return
18+
}
19+
20+
// -----
21+
22+
// expected-error @+1 {{failed to legalize operation 'func.func' that was explicitly marked illegal}}
23+
func.func @argument_non_contiguous(%arg0 : memref<8x8xi4, strided<[1, 8]>>) {
24+
return
25+
}

mlir/test/Dialect/MemRef/emulate-narrow-type.mlir

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8" --cse --verify-diagnostics --split-input-file %s | FileCheck %s
2-
// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=32" --cse --verify-diagnostics --split-input-file %s | FileCheck %s --check-prefix=CHECK32
1+
// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
2+
// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=32" --cse --split-input-file %s | FileCheck %s --check-prefix=CHECK32
33

44
// Expect no conversions.
55
func.func @memref_i8() -> i8 {
@@ -203,18 +203,6 @@ func.func @memref_subview_dynamic_offset_i4(%idx : index) -> i4 {
203203

204204
// -----
205205

206-
207-
func.func @negative_memref_subview_non_contiguous(%idx : index) -> i4 {
208-
%c0 = arith.constant 0 : index
209-
%arr = memref.alloc() : memref<40x40xi4>
210-
// expected-error @+1 {{failed to legalize operation 'memref.subview' that was explicitly marked illegal}}
211-
%subview = memref.subview %arr[%idx, 0] [4, 8] [1, 1] : memref<40x40xi4> to memref<4x8xi4, strided<[40, 1], offset:?>>
212-
%ld = memref.load %subview[%c0, %c0] : memref<4x8xi4, strided<[40, 1], offset:?>>
213-
return %ld : i4
214-
}
215-
216-
// -----
217-
218206
func.func @reinterpret_cast_memref_load_0D() -> i4 {
219207
%0 = memref.alloc() : memref<5xi4>
220208
%reinterpret_cast_0 = memref.reinterpret_cast %0 to offset: [0], sizes: [], strides: [] : memref<5xi4> to memref<i4>
@@ -540,16 +528,3 @@ func.func @memref_copy_i4(%arg0: memref<32x128xi4, 1>, %arg1: memref<32x128xi4>)
540528
// CHECK32-SAME: %[[ARG0:.*]]: memref<512xi32, 1>, %[[ARG1:.*]]: memref<512xi32>
541529
// CHECK32: memref.copy %[[ARG0]], %[[ARG1]]
542530
// CHECK32: return
543-
544-
// -----
545-
546-
!colMajor = memref<8x8xi4, strided<[1, 8]>>
547-
func.func @copy_distinct_layouts(%idx : index) -> i4 {
548-
%c0 = arith.constant 0 : index
549-
%arr = memref.alloc() : memref<8x8xi4>
550-
%arr2 = memref.alloc() : !colMajor
551-
// expected-error @+1 {{failed to legalize operation 'memref.copy' that was explicitly marked illegal}}
552-
memref.copy %arr, %arr2 : memref<8x8xi4> to !colMajor
553-
%ld = memref.load %arr2[%c0, %c0] : !colMajor
554-
return %ld : i4
555-
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// RUN: mlir-opt --memref-emulate-wide-int="widest-int-supported=32" \
2+
// RUN: --split-input-file --verify-diagnostics %s
3+
4+
// Make sure we do not crash on unsupported types.
5+
6+
func.func @alloc_i128() {
7+
// expected-error@+1 {{failed to legalize operation 'memref.alloc' that was explicitly marked illegal}}
8+
%m = memref.alloc() : memref<4xi128, 1>
9+
return
10+
}
11+
12+
// -----
13+
14+
func.func @load_i128(%m: memref<4xi128, 1>) {
15+
%c0 = arith.constant 0 : index
16+
// expected-error@+1 {{failed to legalize operation 'memref.load' that was explicitly marked illegal}}
17+
%v = memref.load %m[%c0] : memref<4xi128, 1>
18+
return
19+
}
20+
21+
// -----
22+
23+
func.func @store_i128(%c1: i128, %m: memref<4xi128, 1>) {
24+
%c0 = arith.constant 0 : index
25+
// expected-error@+1 {{failed to legalize operation 'memref.store' that was explicitly marked illegal}}
26+
memref.store %c1, %m[%c0] : memref<4xi128, 1>
27+
return
28+
}

0 commit comments

Comments
 (0)