Skip to content

Commit d8ed7e0

Browse files
committed
[mlir][memref]: Fix Bug in GlobalOp Verifier
When reconstructing the corresponding tensor type of a memref ensure we include the memory space of the tensor if it exists. Signed-off-by: Jack Frankland <[email protected]>
1 parent 148111f commit d8ed7e0

File tree

3 files changed

+10
-1
lines changed

3 files changed

+10
-1
lines changed

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
5959
/// type.
6060
Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
6161
if (auto memref = llvm::dyn_cast<MemRefType>(type))
62-
return RankedTensorType::get(memref.getShape(), memref.getElementType());
62+
return RankedTensorType::get(memref.getShape(), memref.getElementType(),
63+
memref.getMemorySpace());
6364
if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
6465
return UnrankedTensorType::get(memref.getElementType());
6566
return NoneType::get(type.getContext());

mlir/test/Dialect/MemRef/invalid.mlir

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,11 @@ memref.global "priate" constant @memref5 : memref<2xf32> = uninitialized
342342

343343
// -----
344344

345+
// expected-error @+1 {{op initial value expected to be of type 'tensor<1xf16>', but was of type 'tensor<1xf16, 1 : i32>'}}
346+
"memref.global"() <{constant, initial_value = dense<1.000000e+00> : tensor<1xf16, 1 : i32>, sym_name = "memref6", sym_visibility = "private", type = memref<1xf16>}> : () -> ()
347+
348+
// -----
349+
345350
func.func @nonexistent_global_memref() {
346351
// expected-error @+1 {{'gv' does not reference a valid global memref}}
347352
%0 = memref.get_global @gv : memref<3xf32>

mlir/test/Dialect/MemRef/ops.mlir

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@ memref.global "private" @memref3 : memref<2xf32> = uninitialized
174174
// CHECK-LABEL: memref.global "private" constant @memref4 : memref<2xf32> = uninitialized
175175
memref.global "private" constant @memref4 : memref<2xf32> = uninitialized
176176

177+
// CHECK-LABEL: memref.global "private" constant @memref5 : memref<1xf16, 1 : i32> = dense<1.000000e+00>
178+
"memref.global"() <{constant, initial_value = dense<1.000000e+00> : tensor<1xf16, 1 : i32>, sym_name = "memref5", sym_visibility = "private", type = memref<1xf16, 1 : i32>}> : () -> ()
179+
177180
// CHECK-LABEL: func @read_global_memref
178181
func.func @read_global_memref() {
179182
%0 = memref.get_global @memref0 : memref<2xf32>

0 commit comments

Comments
 (0)