Skip to content

Commit ea0f6bc

Browse files
[mlir][memref] Check memory space before lowering alloc ops
1 parent 5748ddb commit ea0f6bc

File tree

4 files changed

+13
-14
lines changed

4 files changed

+13
-14
lines changed

mlir/include/mlir/Conversion/LLVMCommon/Pattern.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,11 @@ class ConvertToLLVMPattern : public ConversionPattern {
7575
ValueRange indices,
7676
ConversionPatternRewriter &rewriter) const;
7777

78-
/// Returns if the given memref has identity maps and the element type is
79-
/// convertible to LLVM.
80-
bool isConvertibleAndHasIdentityMaps(MemRefType type) const;
78+
/// Returns if the given memref type is convertible to LLVM and has an
79+
/// identity layout map. If `verifyMemorySpace` is set to "false", the memory
80+
/// space of the memref type is ignored.
81+
bool isConvertibleAndHasIdentityMaps(MemRefType type,
82+
bool verifyMemorySpace = true) const;
8183

8284
/// Returns the type of a pointer to an element of the memref.
8385
Type getElementPtrType(MemRefType type) const;

mlir/lib/Conversion/LLVMCommon/Pattern.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,13 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
9898
// Check if the MemRefType `type` is supported by the lowering. We currently
9999
// only support memrefs with identity maps.
100100
bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps(
101-
MemRefType type) const {
102-
if (!typeConverter->convertType(type.getElementType()))
101+
MemRefType type, bool verifyMemorySpace) const {
102+
if (!type.getLayout().isIdentity())
103103
return false;
104-
return type.getLayout().isIdentity();
104+
// If the memory space should not be verified, just check the element type.
105+
Type typeToVerify =
106+
verifyMemorySpace ? static_cast<Type>(type) : type.getElementType();
107+
return static_cast<bool>(typeConverter->convertType(typeToVerify));
105108
}
106109

107110
Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const {

mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,7 @@ std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
7373
MemRefType memRefType = getMemRefResultType(op);
7474
// Allocate the underlying buffer.
7575
Type elementPtrType = this->getElementPtrType(memRefType);
76-
if (!elementPtrType) {
77-
emitError(loc, "conversion of memref memory space ")
78-
<< memRefType.getMemorySpace()
79-
<< " to integer address space "
80-
"failed. Consider adding memory space conversions.";
81-
}
76+
assert(elementPtrType && "could not compute element ptr type");
8277
FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getNotalignedAllocFn(
8378
getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
8479
getIndexType());

mlir/test/Conversion/MemRefToLLVM/invalid.mlir

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ func.func @bad_address_space(%a: memref<2xindex, "foo">) {
2222

2323
// CHECK-LABEL: @invalid_int_conversion
2424
func.func @invalid_int_conversion() {
25-
// expected-error@+1 {{conversion of memref memory space 1 : ui64 to integer address space failed. Consider adding memory space conversions.}}
25+
// expected-error@unknown{{conversion of memref memory space 1 : ui64 to integer address space failed. Consider adding memory space conversions.}}
2626
%alloc = memref.alloc() {alignment = 64 : i64} : memref<10xf32, 1 : ui64>
2727
return
2828
}
@@ -32,7 +32,6 @@ func.func @invalid_int_conversion() {
3232
// expected-error@unknown{{conversion of memref memory space #gpu.address_space<workgroup> to integer address space failed. Consider adding memory space conversions}}
3333
// CHECK-LABEL: @issue_70160
3434
func.func @issue_70160() {
35-
// expected-error@+1{{conversion of memref memory space #gpu.address_space<workgroup> to integer address space failed. Consider adding memory space conversions}}
3635
%alloc = memref.alloc() : memref<1x32x33xi32, #gpu.address_space<workgroup>>
3736
%alloc1 = memref.alloc() : memref<i32>
3837
%c0 = arith.constant 0 : index

0 commit comments

Comments
 (0)