Skip to content

Commit 705f048

Browse files
committed
[mlir] MemRefToLLVM: convert memref.view operations for empty memrefs
Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D126094
1 parent 8801a5d commit 705f048

File tree

2 files changed

+29
-3
lines changed

2 files changed

+29
-3
lines changed

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1848,6 +1848,12 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
18481848
return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
18491849
assert(offset == 0 && "expected offset to be 0");
18501850

1851+
// Target memref must be contiguous in memory (innermost stride is 1), or
1852+
// empty (special case when at least one of the memref dimensions is 0).
1853+
if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
1854+
return viewOp.emitWarning("cannot cast to non-contiguous shape"),
1855+
failure();
1856+
18511857
// Create the descriptor.
18521858
MemRefDescriptor sourceMemRef(adaptor.source());
18531859
auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
@@ -1884,9 +1890,6 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
18841890
return rewriter.replaceOp(viewOp, {targetMemRef}), success();
18851891

18861892
// Fields 4 and 5: Update sizes and strides.
1887-
if (strides.back() != 1)
1888-
return viewOp.emitWarning("cannot cast to non-contiguous shape"),
1889-
failure();
18901893
Value stride = nullptr, nextSize = nullptr;
18911894
for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
18921895
// Update size.

mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,29 @@ func.func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
8989

9090
// -----
9191

92+
// CHECK-LABL: func @view_empty_memref(
93+
// CHECK: %[[ARG0:.*]]: index,
94+
// CHECK: %[[ARG1:.*]]: memref<0xi8>)
95+
func.func @view_empty_memref(%offset: index, %mem: memref<0xi8>) {
96+
97+
// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
98+
// CHECK: llvm.mlir.constant(0 : index) : i64
99+
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
100+
// CHECK: llvm.mlir.constant(4 : index) : i64
101+
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
102+
// CHECK: llvm.mlir.constant(0 : index) : i64
103+
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
104+
// CHECK: llvm.mlir.constant(0 : index) : i64
105+
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
106+
// CHECK: llvm.mlir.constant(0 : index) : i64
107+
// CHECK: = llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
108+
%0 = memref.view %mem[%offset][] : memref<0xi8> to memref<0x4xf32>
109+
110+
return
111+
}
112+
113+
// -----
114+
92115
// CHECK-LABEL: func @subview(
93116
// CHECK: %[[MEM:.*]]: memref<{{.*}}>,
94117
// CHECK: %[[ARG0f:[a-zA-Z0-9]*]]: index,

0 commit comments

Comments
 (0)