Skip to content

Commit 2ac2e9a

Browse files
authored
[mlir][LLVM] Improve lowering of llvm.byval function arguments (#100028)
When a function argument is annotated with the `llvm.byval` attribute, [LLVM expects](https://llvm.org/docs/LangRef.html#parameter-attributes) the function argument type to be an `llvm.ptr`. For example: ``` func.func (%args0 : llvm.ptr {llvm.byval = !llvm.struct<(i32)>} { ... } ``` Unfortunately, this makes the type conversion context-dependent, which is something that the type conversion infrastructure (i.e., `LLVMTypeConverter` in this particular case) doesn't support. For example, we may want to convert `MyType` to `llvm.struct<(i32)>` in general, but to an `llvm.ptr` type only when it's a function argument passed by value. To fix this problem, this PR changes the FuncToLLVM conversion logic to generate an `llvm.ptr` when the function argument has a `llvm.byval` attribute. An `llvm.load` is inserted into the function to retrieve the value expected by the argument users.
1 parent c1912b4 commit 2ac2e9a

File tree

5 files changed

+184
-12
lines changed

5 files changed

+184
-12
lines changed

mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
namespace mlir {
2222

2323
class DataLayoutAnalysis;
24+
class FunctionOpInterface;
2425
class LowerToLLVMOptions;
2526

2627
namespace LLVM {
@@ -50,13 +51,25 @@ class LLVMTypeConverter : public TypeConverter {
5051
LLVMTypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options,
5152
const DataLayoutAnalysis *analysis = nullptr);
5253

53-
/// Convert a function type. The arguments and results are converted one by
54+
/// Convert a function type. The arguments and results are converted one by
5455
/// one and results are packed into a wrapped LLVM IR structure type. `result`
5556
/// is populated with argument mapping.
5657
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic,
5758
bool useBarePtrCallConv,
5859
SignatureConversion &result) const;
5960

61+
/// Convert a function type. The arguments and results are converted one by
62+
/// one and results are packed into a wrapped LLVM IR structure type. `result`
63+
/// is populated with argument mapping. Converted types of `llvm.byval` and
64+
/// `llvm.byref` function arguments which are not LLVM pointers are overridden
65+
/// with LLVM pointers. Overridden arguments are returned in
66+
/// `byValRefNonPtrAttrs`.
67+
Type convertFunctionSignature(FunctionOpInterface funcOp, bool isVariadic,
68+
bool useBarePtrCallConv,
69+
LLVMTypeConverter::SignatureConversion &result,
70+
SmallVectorImpl<std::optional<NamedAttribute>>
71+
&byValRefNonPtrAttrs) const;
72+
6073
/// Convert a non-empty list of types to be returned from a function into an
6174
/// LLVM-compatible type. In particular, if more than one value is returned,
6275
/// create an LLVM dialect structure type with elements that correspond to
@@ -159,12 +172,26 @@ class LLVMTypeConverter : public TypeConverter {
159172
SmallVector<Type> &getCurrentThreadRecursiveStack();
160173

161174
private:
162-
/// Convert a function type. The arguments and results are converted one by
163-
/// one. Additionally, if the function returns more than one value, pack the
175+
/// Convert a function type. The arguments and results are converted one by
176+
/// one. Additionally, if the function returns more than one value, pack the
164177
/// results into an LLVM IR structure type so that the converted function type
165178
/// returns at most one result.
166179
Type convertFunctionType(FunctionType type) const;
167180

181+
/// Common implementation for `convertFunctionSignature` methods. Convert a
182+
/// function type. The arguments and results are converted one by one and
183+
/// results are packed into a wrapped LLVM IR structure type. `result` is
184+
/// populated with argument mapping. If `byValRefNonPtrAttrs` is provided,
185+
/// converted types of `llvm.byval` and `llvm.byref` function arguments which
186+
/// are not LLVM pointers are overridden with LLVM pointers. `llvm.byval` and
187+
/// `llvm.byref` arguments that were already converted to LLVM pointer types
188+
/// are removed from 'byValRefNonPtrAttrs`.
189+
Type convertFunctionSignatureImpl(
190+
FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
191+
LLVMTypeConverter::SignatureConversion &result,
192+
SmallVectorImpl<std::optional<NamedAttribute>> *byValRefNonPtrAttrs)
193+
const;
194+
168195
/// Convert the index type. Uses llvmModule data layout to create an integer
169196
/// of the pointer bitwidth.
170197
Type convertIndexType(IndexType type) const;

mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,38 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
267267
}
268268
}
269269

270+
/// Inserts `llvm.load` ops in the function body to restore the expected pointee
271+
/// value from `llvm.byval`/`llvm.byref` function arguments that were converted
272+
/// to LLVM pointer types.
273+
static void restoreByValRefArgumentType(
274+
ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter,
275+
ArrayRef<std::optional<NamedAttribute>> byValRefNonPtrAttrs,
276+
LLVM::LLVMFuncOp funcOp) {
277+
// Nothing to do for function declarations.
278+
if (funcOp.isExternal())
279+
return;
280+
281+
ConversionPatternRewriter::InsertionGuard guard(rewriter);
282+
rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front());
283+
284+
for (const auto &[arg, byValRefAttr] :
285+
llvm::zip(funcOp.getArguments(), byValRefNonPtrAttrs)) {
286+
// Skip argument if no `llvm.byval` or `llvm.byref` attribute.
287+
if (!byValRefAttr)
288+
continue;
289+
290+
// Insert load to retrieve the actual argument passed by value/reference.
291+
assert(isa<LLVM::LLVMPointerType>(arg.getType()) &&
292+
"Expected LLVM pointer type for argument with "
293+
"`llvm.byval`/`llvm.byref` attribute");
294+
Type resTy = typeConverter.convertType(
295+
cast<TypeAttr>(byValRefAttr->getValue()).getValue());
296+
297+
auto valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
298+
rewriter.replaceAllUsesExcept(arg, valueArg, valueArg);
299+
}
300+
}
301+
270302
FailureOr<LLVM::LLVMFuncOp>
271303
mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
272304
ConversionPatternRewriter &rewriter,
@@ -280,10 +312,14 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
280312
// Convert the original function arguments. They are converted using the
281313
// LLVMTypeConverter provided to this legalization pattern.
282314
auto varargsAttr = funcOp->getAttrOfType<BoolAttr>(varargsAttrName);
315+
// Gather `llvm.byval` and `llvm.byref` arguments whose type convertion was
316+
// overriden with an LLVM pointer type for later processing.
317+
SmallVector<std::optional<NamedAttribute>> byValRefNonPtrAttrs;
283318
TypeConverter::SignatureConversion result(funcOp.getNumArguments());
284319
auto llvmType = converter.convertFunctionSignature(
285-
funcTy, varargsAttr && varargsAttr.getValue(),
286-
shouldUseBarePtrCallConv(funcOp, &converter), result);
320+
funcOp, varargsAttr && varargsAttr.getValue(),
321+
shouldUseBarePtrCallConv(funcOp, &converter), result,
322+
byValRefNonPtrAttrs);
287323
if (!llvmType)
288324
return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
289325

@@ -398,6 +434,12 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
398434
"region types conversion failed");
399435
}
400436

437+
// Fix the type mismatch between the materialized `llvm.ptr` and the expected
438+
// pointee type in the function body when converting `llvm.byval`/`llvm.byref`
439+
// function arguments.
440+
restoreByValRefArgumentType(rewriter, converter, byValRefNonPtrAttrs,
441+
newFuncOp);
442+
401443
if (!shouldUseBarePtrCallConv(funcOp, &converter)) {
402444
if (funcOp->getAttrOfType<UnitAttr>(
403445
LLVM::LLVMDialect::getEmitCWrapperAttrName())) {

mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Lines changed: 67 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -270,13 +270,42 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) const {
270270
return LLVM::LLVMPointerType::get(type.getContext());
271271
}
272272

273+
/// Returns the `llvm.byval` or `llvm.byref` attributes that are present in the
274+
/// function arguments. Returns an empty container if none of these attributes
275+
/// are found in any of the arguments.
276+
static void
277+
filterByValRefArgAttrs(FunctionOpInterface funcOp,
278+
SmallVectorImpl<std::optional<NamedAttribute>> &result) {
279+
assert(result.empty() && "Unexpected non-empty output");
280+
result.resize(funcOp.getNumArguments(), std::nullopt);
281+
bool foundByValByRefAttrs = false;
282+
for (int argIdx : llvm::seq(funcOp.getNumArguments())) {
283+
for (NamedAttribute namedAttr : funcOp.getArgAttrs(argIdx)) {
284+
if ((namedAttr.getName() == LLVM::LLVMDialect::getByValAttrName() ||
285+
namedAttr.getName() == LLVM::LLVMDialect::getByRefAttrName())) {
286+
foundByValByRefAttrs = true;
287+
result[argIdx] = namedAttr;
288+
break;
289+
}
290+
}
291+
}
292+
293+
if (!foundByValByRefAttrs)
294+
result.clear();
295+
}
296+
273297
// Function types are converted to LLVM Function types by recursively converting
274-
// argument and result types. If MLIR Function has zero results, the LLVM
275-
// Function has one VoidType result. If MLIR Function has more than one result,
298+
// argument and result types. If MLIR Function has zero results, the LLVM
299+
// Function has one VoidType result. If MLIR Function has more than one result,
276300
// they are into an LLVM StructType in their order of appearance.
277-
Type LLVMTypeConverter::convertFunctionSignature(
301+
// If `byValRefNonPtrAttrs` is provided, converted types of `llvm.byval` and
302+
// `llvm.byref` function arguments which are not LLVM pointers are overridden
303+
// with LLVM pointers. `llvm.byval` and `llvm.byref` arguments that were already
304+
// converted to LLVM pointer types are removed from 'byValRefNonPtrAttrs`.
305+
Type LLVMTypeConverter::convertFunctionSignatureImpl(
278306
FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
279-
LLVMTypeConverter::SignatureConversion &result) const {
307+
LLVMTypeConverter::SignatureConversion &result,
308+
SmallVectorImpl<std::optional<NamedAttribute>> *byValRefNonPtrAttrs) const {
280309
// Select the argument converter depending on the calling convention.
281310
useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
282311
auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
@@ -286,6 +315,19 @@ Type LLVMTypeConverter::convertFunctionSignature(
286315
SmallVector<Type, 8> converted;
287316
if (failed(funcArgConverter(*this, type, converted)))
288317
return {};
318+
319+
// Rewrite converted type of `llvm.byval` or `llvm.byref` function
320+
// argument that was not converted to an LLVM pointer types.
321+
if (byValRefNonPtrAttrs != nullptr && !byValRefNonPtrAttrs->empty() &&
322+
converted.size() == 1 && (*byValRefNonPtrAttrs)[idx].has_value()) {
323+
// If the argument was already converted to an LLVM pointer type, we stop
324+
// tracking it as it doesn't need more processing.
325+
if (isa<LLVM::LLVMPointerType>(converted[0]))
326+
(*byValRefNonPtrAttrs)[idx] = std::nullopt;
327+
else
328+
converted[0] = LLVM::LLVMPointerType::get(&getContext());
329+
}
330+
289331
result.addInputs(idx, converted);
290332
}
291333

@@ -302,6 +344,27 @@ Type LLVMTypeConverter::convertFunctionSignature(
302344
isVariadic);
303345
}
304346

347+
Type LLVMTypeConverter::convertFunctionSignature(
348+
FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
349+
LLVMTypeConverter::SignatureConversion &result) const {
350+
return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
351+
result,
352+
/*byValRefNonPtrAttrs=*/nullptr);
353+
}
354+
355+
Type LLVMTypeConverter::convertFunctionSignature(
356+
FunctionOpInterface funcOp, bool isVariadic, bool useBarePtrCallConv,
357+
LLVMTypeConverter::SignatureConversion &result,
358+
SmallVectorImpl<std::optional<NamedAttribute>> &byValRefNonPtrAttrs) const {
359+
// Gather all `llvm.byval` and `llvm.byref` function arguments. Only those
360+
// that were not converted to LLVM pointer types will be returned for further
361+
// processing.
362+
filterByValRefArgAttrs(funcOp, byValRefNonPtrAttrs);
363+
auto funcTy = cast<FunctionType>(funcOp.getFunctionType());
364+
return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
365+
result, &byValRefNonPtrAttrs);
366+
}
367+
305368
/// Converts the function type to a C-compatible format, in particular using
306369
/// pointers to memref descriptors for arguments.
307370
std::pair<LLVM::LLVMFunctionType, LLVM::LLVMStructType>

mlir/test/Transforms/test-convert-func-op.mlir

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -test-convert-func-op | FileCheck %s
1+
// RUN: mlir-opt %s -test-convert-func-op --split-input-file | FileCheck %s
22

33
// CHECK-LABEL: llvm.func @add
44
func.func @add(%arg0: i32, %arg1: i32) -> i32 attributes { llvm.emit_c_interface } {
@@ -10,3 +10,31 @@ func.func @add(%arg0: i32, %arg1: i32) -> i32 attributes { llvm.emit_c_interface
1010
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]+]]: i32
1111
// CHECK-NEXT: [[RES:%.*]] = llvm.call @add([[ARG0]], [[ARG1]])
1212
// CHECK-NEXT: llvm.return [[RES]]
13+
14+
// -----
15+
16+
// Test that `llvm.byval` arguments are converted to `llvm.ptr` and the actual
17+
// value is retrieved within the `llvm.func`.
18+
19+
// CHECK-LABEL: llvm.func @byval
20+
func.func @byval(%arg0: !test.smpla {llvm.byval = !test.smpla}) -> !test.smpla {
21+
return %arg0 : !test.smpla
22+
}
23+
24+
// CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr {llvm.byval = !llvm.struct<(i8, i8)>}) -> !llvm.struct<(i8, i8)>
25+
// CHECK: %[[LD:.*]] = llvm.load %[[ARG0]] : !llvm.ptr -> !llvm.struct<(i8, i8)>
26+
// CHECK: llvm.return %[[LD]] : !llvm.struct<(i8, i8)>
27+
28+
// -----
29+
30+
// Test that `llvm.byref` arguments are converted to `llvm.ptr` and the actual
31+
// value is retrieved within the `llvm.func`.
32+
33+
// CHECK-LABEL: llvm.func @byref
34+
func.func @byref(%arg0: !test.smpla {llvm.byref = !test.smpla}) -> !test.smpla {
35+
return %arg0 : !test.smpla
36+
}
37+
38+
// CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr {llvm.byref = !llvm.struct<(i8, i8)>}) -> !llvm.struct<(i8, i8)>
39+
// CHECK: %[[LD:.*]] = llvm.load %[[ARG0]] : !llvm.ptr -> !llvm.struct<(i8, i8)>
40+
// CHECK: llvm.return %[[LD]] : !llvm.struct<(i8, i8)>

mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,23 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<func::ReturnOp> {
4747
LogicalResult
4848
matchAndRewrite(func::ReturnOp returnOp, OpAdaptor adaptor,
4949
ConversionPatternRewriter &rewriter) const override {
50-
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp,
51-
returnOp->getOperands());
50+
SmallVector<Type> resTys;
51+
if (failed(typeConverter->convertTypes(returnOp->getResultTypes(), resTys)))
52+
return failure();
53+
54+
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, resTys,
55+
adaptor.getOperands());
5256
return success();
5357
}
5458
};
5559

60+
static std::optional<Type>
61+
convertSimpleATypeToStruct(test::SimpleAType simpleTy) {
62+
MLIRContext *ctx = simpleTy.getContext();
63+
SmallVector<Type> memberTys(2, IntegerType::get(ctx, /*width=*/8));
64+
return LLVM::LLVMStructType::getLiteral(ctx, memberTys);
65+
}
66+
5667
struct TestConvertFuncOp
5768
: public PassWrapper<TestConvertFuncOp, OperationPass<ModuleOp>> {
5869
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertFuncOp)
@@ -74,6 +85,7 @@ struct TestConvertFuncOp
7485
LowerToLLVMOptions options(ctx);
7586
// Populate type conversions.
7687
LLVMTypeConverter typeConverter(ctx, options);
88+
typeConverter.addConversion(convertSimpleATypeToStruct);
7789

7890
RewritePatternSet patterns(ctx);
7991
patterns.add<FuncOpConversion>(typeConverter);

0 commit comments

Comments
 (0)