Skip to content

[mlir] Add pass to add comdat to all linkonce functions #65270

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/Transforms/AddComdats.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//===- AddComdats.h - Add comdats to linkonce functions -*- C++ -*---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_LLVMIR_TRANSFORMS_ADDCOMDATS_H
#define MLIR_DIALECT_LLVMIR_TRANSFORMS_ADDCOMDATS_H

#include <memory>

namespace mlir {

class Pass;

namespace LLVM {

#define GEN_PASS_DECL_LLVMADDCOMDATS
#include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc"

} // namespace LLVM
} // namespace mlir

#endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_ADDCOMDATS_H
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_LLVMIR_TRANSFORMS_PASSES_H
#define MLIR_DIALECT_LLVMIR_TRANSFORMS_PASSES_H

#include "mlir/Dialect/LLVMIR/Transforms/AddComdats.h"
#include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"
#include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h"
#include "mlir/Dialect/LLVMIR/Transforms/RequestCWrappers.h"
Expand Down
13 changes: 13 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,19 @@

include "mlir/Pass/PassBase.td"

def LLVMAddComdats : Pass<"llvm-add-comdats", "::mlir::ModuleOp"> {
let summary = "Add comdats to linkonce and linkonce_odr functions";
let description = [{
Add an any COMDAT to every linkonce and linkonce_odr function.
This is necessary on Windows to link these functions as the system
linker won't link weak symbols without a COMDAT. It also provides better
behavior than standard weak symbols on ELF-based platforms.
This pass will still add COMDATs on platforms that do not support them,
for example macOS, so should only be run when the target platform supports
COMDATs.
}];
}

def LLVMLegalizeForExport : Pass<"llvm-legalize-for-export"> {
let summary = "Legalize LLVM dialect to be convertible to LLVM IR";
let constructor = "::mlir::LLVM::createLegalizeForExportPass()";
Expand Down
64 changes: 64 additions & 0 deletions mlir/lib/Dialect/LLVMIR/Transforms/AddComdats.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
//===- AddComdats.cpp - Add comdats to linkonce functions -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/Transforms/AddComdats.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
namespace LLVM {
#define GEN_PASS_DEF_LLVMADDCOMDATS
#include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc"
} // namespace LLVM
} // namespace mlir

using namespace mlir;

static void addComdat(LLVM::LLVMFuncOp &op, OpBuilder &builder,
SymbolTable &symbolTable, ModuleOp &module) {
const char *comdatName = "__llvm_comdat";
mlir::LLVM::ComdatOp comdatOp =
symbolTable.lookup<mlir::LLVM::ComdatOp>(comdatName);
if (!comdatOp) {
PatternRewriter::InsertionGuard guard(builder);
builder.setInsertionPointToStart(module.getBody());
comdatOp =
builder.create<mlir::LLVM::ComdatOp>(module.getLoc(), comdatName);
symbolTable.insert(comdatOp);
}

PatternRewriter::InsertionGuard guard(builder);
builder.setInsertionPointToStart(&comdatOp.getBody().back());
auto selectorOp = builder.create<mlir::LLVM::ComdatSelectorOp>(
comdatOp.getLoc(), op.getSymName(), mlir::LLVM::comdat::Comdat::Any);
op.setComdatAttr(mlir::SymbolRefAttr::get(
builder.getContext(), comdatName,
mlir::FlatSymbolRefAttr::get(selectorOp.getSymNameAttr())));
}

namespace {
struct AddComdatsPass : public LLVM::impl::LLVMAddComdatsBase<AddComdatsPass> {
void runOnOperation() override {
OpBuilder builder{&getContext()};
ModuleOp mod = getOperation();

std::unique_ptr<SymbolTable> symbolTable;
auto getSymTab = [&]() -> SymbolTable & {
if (!symbolTable)
symbolTable = std::make_unique<SymbolTable>(mod);
return *symbolTable;
};
for (auto op : mod.getBody()->getOps<LLVM::LLVMFuncOp>()) {
if (op.getLinkage() == LLVM::Linkage::Linkonce ||
op.getLinkage() == LLVM::Linkage::LinkonceODR) {
addComdat(op, builder, getSymTab(), mod);
}
}
}
};
} // namespace
1 change: 1 addition & 0 deletions mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRLLVMIRTransforms
AddComdats.cpp
DIScopeForLLVMFuncOp.cpp
LegalizeForExport.cpp
OptimizeForNVVM.cpp
Expand Down
17 changes: 17 additions & 0 deletions mlir/test/Dialect/LLVMIR/add-linkonce-comdat.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: mlir-opt -llvm-add-comdats -verify-diagnostics %s | FileCheck %s

// CHECK: llvm.comdat @__llvm_comdat {
// CHECK-DAG: llvm.comdat_selector @linkonce any
// CHECK-DAG: llvm.comdat_selector @linkonce_odr any
// CHECK: }

// CHECK: llvm.func linkonce @linkonce() comdat(@__llvm_comdat::@linkonce)
llvm.func linkonce @linkonce() {
llvm.return
}

// CHECK: llvm.func linkonce_odr @linkonce_odr() comdat(@__llvm_comdat::@linkonce_odr)
llvm.func linkonce_odr @linkonce_odr() {
llvm.return
}