Skip to content

[mlir][Conversion] FuncToLLVM: Simplify bare-pointer handling #96393

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 0 additions & 53 deletions mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,55 +268,6 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
}
}

/// Rewrites the entry block of `funcOp` so that every `memref` argument, which
/// is passed as a bare pointer under the bare pointer calling convention, is
/// promoted to a `MemRefDescriptor` at the very start of the function. After
/// this, all memrefs in the body share a uniform descriptor representation.
///
/// `oldArgTypes` are the argument types prior to conversion, one per entry
/// block argument. Functions without a body are left untouched. Note that the
/// `loc` parameter is not consulted: the generated ops are tagged with the
/// location of `funcOp`.
static void modifyFuncOpToUseBarePtrCallingConv(
    ConversionPatternRewriter &rewriter, Location loc,
    const LLVMTypeConverter &typeConverter, LLVM::LLVMFuncOp funcOp,
    TypeRange oldArgTypes) {
  // Declarations have no entry block to rewrite.
  if (funcOp.getBody().empty())
    return;

  Block *entry = &funcOp.getBody().front();
  auto entryArgs = entry->getArguments();
  assert(entryArgs.size() == oldArgTypes.size() &&
         "The number of arguments and types doesn't match");

  // All generated ops carry the location of the converted function.
  Location funcLoc = funcOp.getLoc();

  // Emit everything at the top of the entry block; the guard restores the
  // rewriter's previous insertion point on exit.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(entry);
  for (auto argAndType : llvm::zip(entryArgs, oldArgTypes)) {
    BlockArgument barePtr = std::get<0>(argAndType);
    Type originalType = std::get<1>(argAndType);

    // The bare pointer calling convention cannot express unranked memrefs;
    // the caller is expected to have bailed out before reaching this point.
    assert(!isa<UnrankedMemRefType>(originalType) &&
           "Unranked memref is not supported");

    // Only ranked memref arguments need to be promoted.
    auto rankedTy = dyn_cast<MemRefType>(originalType);
    if (!rankedTy)
      continue;

    // Step 1: redirect every current use of the bare pointer to a temporary
    // undef of the descriptor type. Without this indirection, the final
    // replacement would also rewrite the uses of the bare pointer inside the
    // descriptor-building instructions themselves.
    // TODO: A rewriter utility that handles this use case directly would make
    // the placeholder unnecessary.
    auto tempDesc = rewriter.create<LLVM::UndefOp>(
        funcLoc, typeConverter.convertType(rankedTy));
    rewriter.replaceUsesOfBlockArgument(barePtr, tempDesc);

    // Step 2: build the real descriptor from the bare pointer and substitute
    // it for the temporary.
    Value descriptor = MemRefDescriptor::fromStaticShape(
        rewriter, funcLoc, typeConverter, rankedTy, barePtr);
    rewriter.replaceOp(tempDesc, {descriptor});
  }
}

FailureOr<LLVM::LLVMFuncOp>
mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
ConversionPatternRewriter &rewriter,
Expand Down Expand Up @@ -462,10 +413,6 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
wrapForExternalCallers(rewriter, funcOp->getLoc(), converter, funcOp,
newFuncOp);
}
} else {
modifyFuncOpToUseBarePtrCallingConv(
rewriter, funcOp->getLoc(), converter, newFuncOp,
llvm::cast<FunctionType>(funcOp.getFunctionType()).getInputs());
}

return newFuncOp;
Expand Down
29 changes: 0 additions & 29 deletions mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,35 +182,6 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
&signatureConversion)))
return failure();

// If bare memref pointers are being used, remap them back to memref
// descriptors. This must be done after signature conversion to get rid of the
// unrealized casts.
if (getTypeConverter()->getOptions().useBarePtrCallConv) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&llvmFuncOp.getBody().front());
for (const auto [idx, argTy] :
llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
auto memrefTy = dyn_cast<MemRefType>(argTy);
if (!memrefTy)
continue;
assert(memrefTy.hasStaticShape() &&
"Bare pointer convertion used with dynamically-shaped memrefs");
// Use a placeholder when replacing uses of the memref argument to prevent
// circular replacements.
auto remapping = signatureConversion.getInputMapping(idx);
assert(remapping && remapping->size == 1 &&
"Type converter should produce 1-to-1 mapping for bare memrefs");
BlockArgument newArg =
llvmFuncOp.getBody().getArgument(remapping->inputNo);
auto placeholder = rewriter.create<LLVM::UndefOp>(
loc, getTypeConverter()->convertType(memrefTy));
rewriter.replaceUsesOfBlockArgument(newArg, placeholder);
Value desc = MemRefDescriptor::fromStaticShape(
rewriter, loc, *getTypeConverter(), memrefTy, newArg);
rewriter.replaceOp(placeholder, {desc});
}
}

// Get memref type from function arguments and set the noalias to
// pointer arguments.
for (const auto [idx, argTy] :
Expand Down
22 changes: 17 additions & 5 deletions mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,18 +159,30 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
addArgumentMaterialization(
[&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
Location loc) -> std::optional<Value> {
if (inputs.size() == 1)
if (inputs.size() == 1) {
// Bare pointers are not supported for unranked memrefs because a
// memref descriptor cannot be built just from a bare pointer.
return std::nullopt;
}
return UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType,
inputs);
});
addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
ValueRange inputs,
Location loc) -> std::optional<Value> {
// TODO: bare ptr conversion could be handled here but we would need a way
// to distinguish between FuncOp and other regions.
if (inputs.size() == 1)
return std::nullopt;
if (inputs.size() == 1) {
// This is a bare pointer. We allow bare pointers only for function entry
// blocks.
BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
if (!barePtr)
return std::nullopt;
Block *block = barePtr.getOwner();
if (!block->isEntryBlock() ||
!isa<FunctionOpInterface>(block->getParentOp()))
return std::nullopt;
return MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
inputs[0]);
}
return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
});
// Add generic source and target materializations to handle cases where
Expand Down
Loading