@@ -153,6 +153,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
153
153
type.isVarArg ());
154
154
});
155
155
156
+ // Helper function that checks if the given value range is a bare pointer.
157
+ auto isBarePointer = [](ValueRange values) {
158
+ return values.size () == 1 &&
159
+ isa<LLVM::LLVMPointerType>(values.front ().getType ());
160
+ };
161
+
156
162
// Argument materializations convert from the new block argument types
157
163
// (multiple SSA values that make up a memref descriptor) back to the
158
164
// original block argument type. The dialect conversion framework will then
@@ -161,11 +167,10 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
161
167
addArgumentMaterialization ([&](OpBuilder &builder,
162
168
UnrankedMemRefType resultType,
163
169
ValueRange inputs, Location loc) {
164
- if (inputs. size () == 1 ) {
165
- // Bare pointers are not supported for unranked memrefs because a
166
- // memref descriptor cannot be built just from a bare pointer.
170
+ // Note: Bare pointers are not supported for unranked memrefs because a
171
+ // memref descriptor cannot be built just from a bare pointer.
172
+ if ( TypeRange (inputs) != getUnrankedMemRefDescriptorFields ())
167
173
return Value ();
168
- }
169
174
Value desc =
170
175
UnrankedMemRefDescriptor::pack (builder, loc, *this , resultType, inputs);
171
176
// An argument materialization must return a value of type
@@ -177,20 +182,17 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
177
182
addArgumentMaterialization ([&](OpBuilder &builder, MemRefType resultType,
178
183
ValueRange inputs, Location loc) {
179
184
Value desc;
180
- if (inputs.size () == 1 ) {
181
- // This is a bare pointer. We allow bare pointers only for function entry
182
- // blocks.
183
- BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front ());
184
- if (!barePtr)
185
- return Value ();
186
- Block *block = barePtr.getOwner ();
187
- if (!block->isEntryBlock () ||
188
- !isa<FunctionOpInterface>(block->getParentOp ()))
189
- return Value ();
185
+ if (isBarePointer (inputs)) {
190
186
desc = MemRefDescriptor::fromStaticShape (builder, loc, *this , resultType,
191
187
inputs[0 ]);
192
- } else {
188
+ } else if (TypeRange (inputs) ==
189
+ getMemRefDescriptorFields (resultType,
190
+ /* unpackAggregates=*/ true )) {
193
191
desc = MemRefDescriptor::pack (builder, loc, *this , resultType, inputs);
192
+ } else {
193
+ // The inputs are neither a bare pointer nor an unpacked memref
194
+ // descriptor. This materialization function cannot be used.
195
+ return Value ();
194
196
}
195
197
// An argument materialization must return a value of type `resultType`,
196
198
// so insert a cast from the memref descriptor type (!llvm.struct) to the
0 commit comments