Skip to content

Commit 5857c76

Browse files
[mlir][LLVM] LLVMTypeConverter: Tighten materialization checks
1 parent a496ab4 commit 5857c76

File tree

1 file changed

+17
-15
lines changed

1 file changed

+17
-15
lines changed

mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
153153
type.isVarArg());
154154
});
155155

156+
// Helper function that checks if the given value range is a bare pointer.
157+
auto isBarePointer = [](ValueRange values) {
158+
return values.size() == 1 &&
159+
isa<LLVM::LLVMPointerType>(values.front().getType());
160+
};
161+
156162
// Argument materializations convert from the new block argument types
157163
// (multiple SSA values that make up a memref descriptor) back to the
158164
// original block argument type. The dialect conversion framework will then
@@ -161,11 +167,10 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
161167
addArgumentMaterialization([&](OpBuilder &builder,
162168
UnrankedMemRefType resultType,
163169
ValueRange inputs, Location loc) {
164-
if (inputs.size() == 1) {
165-
// Bare pointers are not supported for unranked memrefs because a
166-
// memref descriptor cannot be built just from a bare pointer.
170+
// Note: Bare pointers are not supported for unranked memrefs because a
171+
// memref descriptor cannot be built just from a bare pointer.
172+
if (TypeRange(inputs) != getUnrankedMemRefDescriptorFields())
167173
return Value();
168-
}
169174
Value desc =
170175
UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
171176
// An argument materialization must return a value of type
@@ -177,20 +182,17 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
177182
addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
178183
ValueRange inputs, Location loc) {
179184
Value desc;
180-
if (inputs.size() == 1) {
181-
// This is a bare pointer. We allow bare pointers only for function entry
182-
// blocks.
183-
BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
184-
if (!barePtr)
185-
return Value();
186-
Block *block = barePtr.getOwner();
187-
if (!block->isEntryBlock() ||
188-
!isa<FunctionOpInterface>(block->getParentOp()))
189-
return Value();
185+
if (isBarePointer(inputs)) {
190186
desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
191187
inputs[0]);
192-
} else {
188+
} else if (TypeRange(inputs) ==
189+
getMemRefDescriptorFields(resultType,
190+
/*unpackAggregates=*/true)) {
193191
desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
192+
} else {
193+
// The inputs are neither a bare pointer nor an unpacked memref
194+
// descriptor. This materialization function cannot be used.
195+
return Value();
194196
}
195197
// An argument materialization must return a value of type `resultType`,
196198
// so insert a cast from the memref descriptor type (!llvm.struct) to the

0 commit comments

Comments
 (0)