Skip to content

[NFC][LLVM][CodeGen] Refactor MIR Printer #137361

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion llvm/include/llvm/CodeGen/MachineBasicBlock.h
Original file line number Diff line number Diff line change
Expand Up @@ -1263,14 +1263,16 @@ class MachineBasicBlock
/// MachineBranchProbabilityInfo class.
BranchProbability getSuccProbability(const_succ_iterator Succ) const;

// Helper function for MIRPrinter.
bool canPredictBranchProbabilities() const;

private:
/// Return probability iterator corresponding to the I successor iterator.
probability_iterator getProbabilityIterator(succ_iterator I);
const_probability_iterator
getProbabilityIterator(const_succ_iterator I) const;

friend class MachineBranchProbabilityInfo;
friend class MIPrinter;

// Methods used to maintain doubly linked list of blocks...
friend struct ilist_callback_traits<MachineBasicBlock>;
Expand Down
6 changes: 6 additions & 0 deletions llvm/include/llvm/Support/BranchProbability.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#ifndef LLVM_SUPPORT_BRANCHPROBABILITY_H
#define LLVM_SUPPORT_BRANCHPROBABILITY_H

#include "llvm/ADT/ADL.h"
#include "llvm/Support/DataTypes.h"
#include <algorithm>
#include <cassert>
Expand Down Expand Up @@ -62,6 +63,11 @@ class BranchProbability {
static void normalizeProbabilities(ProbabilityIter Begin,
ProbabilityIter End);

template <class ProbabilityContainer>
static void normalizeProbabilities(ProbabilityContainer &&R) {
normalizeProbabilities(adl_begin(R), adl_end(R));
}

uint32_t getNumerator() const { return N; }
static uint32_t getDenominator() { return D; }

Expand Down
158 changes: 50 additions & 108 deletions llvm/lib/CodeGen/MIRPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/CodeGen/MIRYamlMapping.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
Expand Down Expand Up @@ -93,10 +94,6 @@ struct FrameIndexOperand {
}
};

} // end anonymous namespace

namespace llvm {

/// This class prints out the machine functions using the MIR serialization
/// format.
class MIRPrinter {
Expand Down Expand Up @@ -151,7 +148,6 @@ class MIPrinter {
/// Synchronization scope names registered with LLVMContext.
SmallVector<StringRef, 8> SSNs;

bool canPredictBranchProbabilities(const MachineBasicBlock &MBB) const;
bool canPredictSuccessors(const MachineBasicBlock &MBB) const;

public:
Expand All @@ -167,14 +163,13 @@ class MIPrinter {
void printStackObjectReference(int FrameIndex);
void print(const MachineInstr &MI, unsigned OpIdx,
const TargetRegisterInfo *TRI, const TargetInstrInfo *TII,
bool ShouldPrintRegisterTies, LLT TypeToPrint,
bool PrintDef = true);
bool ShouldPrintRegisterTies, SmallBitVector &PrintedTypes,
const MachineRegisterInfo &MRI, bool PrintDef = true);
};

} // end namespace llvm
} // end anonymous namespace

namespace llvm {
namespace yaml {
namespace llvm::yaml {

/// This struct serializes the LLVM IR module.
template <> struct BlockScalarTraits<Module> {
Expand All @@ -188,8 +183,7 @@ template <> struct BlockScalarTraits<Module> {
}
};

} // end namespace yaml
} // end namespace llvm
} // end namespace llvm::yaml

static void printRegMIR(Register Reg, yaml::StringValue &Dest,
const TargetRegisterInfo *TRI) {
Expand Down Expand Up @@ -327,9 +321,8 @@ static void printRegFlags(Register Reg,
const MachineFunction &MF,
const TargetRegisterInfo *TRI) {
auto FlagValues = TRI->getVRegFlagsOfReg(Reg, MF);
for (auto &Flag : FlagValues) {
for (auto &Flag : FlagValues)
RegisterFlags.push_back(yaml::FlowStringValue(Flag.str()));
}
}

void MIRPrinter::convert(yaml::MachineFunction &YamlMF,
Expand Down Expand Up @@ -618,9 +611,8 @@ void MIRPrinter::convertCalledGlobals(yaml::MachineFunction &YMF,
// Sort by position of call instructions.
llvm::sort(YMF.CalledGlobals.begin(), YMF.CalledGlobals.end(),
[](yaml::CalledGlobal A, yaml::CalledGlobal B) {
if (A.CallSite.BlockNum == B.CallSite.BlockNum)
return A.CallSite.Offset < B.CallSite.Offset;
return A.CallSite.BlockNum < B.CallSite.BlockNum;
return std::tie(A.CallSite.BlockNum, A.CallSite.Offset) <
std::tie(B.CallSite.BlockNum, B.CallSite.Offset);
});
}

Expand All @@ -630,11 +622,10 @@ void MIRPrinter::convert(yaml::MachineFunction &MF,
for (const MachineConstantPoolEntry &Constant : ConstantPool.getConstants()) {
std::string Str;
raw_string_ostream StrOS(Str);
if (Constant.isMachineConstantPoolEntry()) {
if (Constant.isMachineConstantPoolEntry())
Constant.Val.MachineCPVal->print(StrOS);
} else {
else
Constant.Val.ConstVal->printAsOperand(StrOS);
}

yaml::MachineConstantPoolValue YamlConstant;
YamlConstant.ID = ID++;
Expand Down Expand Up @@ -693,23 +684,6 @@ void llvm::guessSuccessors(const MachineBasicBlock &MBB,
IsFallthrough = I == MBB.end() || !I->isBarrier();
}

bool
MIPrinter::canPredictBranchProbabilities(const MachineBasicBlock &MBB) const {
if (MBB.succ_size() <= 1)
return true;
if (!MBB.hasSuccessorProbabilities())
return true;

SmallVector<BranchProbability,8> Normalized(MBB.Probs.begin(),
MBB.Probs.end());
BranchProbability::normalizeProbabilities(Normalized.begin(),
Normalized.end());
SmallVector<BranchProbability,8> Equal(Normalized.size());
BranchProbability::normalizeProbabilities(Equal.begin(), Equal.end());

return std::equal(Normalized.begin(), Normalized.end(), Equal.begin());
}

bool MIPrinter::canPredictSuccessors(const MachineBasicBlock &MBB) const {
SmallVector<MachineBasicBlock*,8> GuessedSuccs;
bool GuessedFallthrough;
Expand Down Expand Up @@ -738,7 +712,7 @@ void MIPrinter::print(const MachineBasicBlock &MBB) {

bool HasLineAttributes = false;
// Print the successors
bool canPredictProbs = canPredictBranchProbabilities(MBB);
bool canPredictProbs = MBB.canPredictBranchProbabilities();
// Even if the list of successors is empty, if we cannot guess it,
// we need to print it to tell the parser that the list is empty.
// This is needed, because MI model unreachable as empty blocks
Expand All @@ -750,14 +724,12 @@ void MIPrinter::print(const MachineBasicBlock &MBB) {
OS.indent(2) << "successors:";
if (!MBB.succ_empty())
OS << " ";
ListSeparator LS;
for (auto I = MBB.succ_begin(), E = MBB.succ_end(); I != E; ++I) {
if (I != MBB.succ_begin())
OS << ", ";
OS << printMBBReference(**I);
OS << LS << printMBBReference(**I);
if (!SimplifyMIR || !canPredictProbs)
OS << '('
<< format("0x%08" PRIx32, MBB.getSuccProbability(I).getNumerator())
<< ')';
OS << format("(0x%08" PRIx32 ")",
MBB.getSuccProbability(I).getNumerator());
}
OS << "\n";
HasLineAttributes = true;
Expand All @@ -768,12 +740,9 @@ void MIPrinter::print(const MachineBasicBlock &MBB) {
if (!MBB.livein_empty()) {
const TargetRegisterInfo &TRI = *MRI.getTargetRegisterInfo();
OS.indent(2) << "liveins: ";
bool First = true;
ListSeparator LS;
for (const auto &LI : MBB.liveins_dbg()) {
if (!First)
OS << ", ";
First = false;
OS << printReg(LI.PhysReg, &TRI);
OS << LS << printReg(LI.PhysReg, &TRI);
if (!LI.LaneMask.all())
OS << ":0x" << PrintLaneMask(LI.LaneMask);
}
Expand Down Expand Up @@ -814,14 +783,14 @@ void MIPrinter::print(const MachineInstr &MI) {

SmallBitVector PrintedTypes(8);
bool ShouldPrintRegisterTies = MI.hasComplexRegisterTies();
ListSeparator LS;
unsigned I = 0, E = MI.getNumOperands();
for (; I < E && MI.getOperand(I).isReg() && MI.getOperand(I).isDef() &&
!MI.getOperand(I).isImplicit();
++I) {
if (I)
OS << ", ";
print(MI, I, TRI, TII, ShouldPrintRegisterTies,
MI.getTypeToPrint(I, PrintedTypes, MRI),
for (; I < E; ++I) {
const MachineOperand MO = MI.getOperand(I);
if (!MO.isReg() || !MO.isDef() || MO.isImplicit())
break;
OS << LS;
print(MI, I, TRI, TII, ShouldPrintRegisterTies, PrintedTypes, MRI,
/*PrintDef=*/false);
}

Expand Down Expand Up @@ -869,74 +838,48 @@ void MIPrinter::print(const MachineInstr &MI) {
OS << "samesign ";

OS << TII->getName(MI.getOpcode());
if (I < E)
OS << ' ';

bool NeedComma = false;
for (; I < E; ++I) {
if (NeedComma)
OS << ", ";
print(MI, I, TRI, TII, ShouldPrintRegisterTies,
MI.getTypeToPrint(I, PrintedTypes, MRI));
NeedComma = true;
LS = ListSeparator();

if (I < E) {
OS << ' ';
for (; I < E; ++I) {
OS << LS;
print(MI, I, TRI, TII, ShouldPrintRegisterTies, PrintedTypes, MRI);
}
}

// Print any optional symbols attached to this instruction as-if they were
// operands.
if (MCSymbol *PreInstrSymbol = MI.getPreInstrSymbol()) {
if (NeedComma)
OS << ',';
OS << " pre-instr-symbol ";
OS << LS << " pre-instr-symbol ";
MachineOperand::printSymbol(OS, *PreInstrSymbol);
NeedComma = true;
}
if (MCSymbol *PostInstrSymbol = MI.getPostInstrSymbol()) {
if (NeedComma)
OS << ',';
OS << " post-instr-symbol ";
OS << LS << " post-instr-symbol ";
MachineOperand::printSymbol(OS, *PostInstrSymbol);
NeedComma = true;
}
if (MDNode *HeapAllocMarker = MI.getHeapAllocMarker()) {
if (NeedComma)
OS << ',';
OS << " heap-alloc-marker ";
OS << LS << " heap-alloc-marker ";
HeapAllocMarker->printAsOperand(OS, MST);
NeedComma = true;
}
if (MDNode *PCSections = MI.getPCSections()) {
if (NeedComma)
OS << ',';
OS << " pcsections ";
OS << LS << " pcsections ";
PCSections->printAsOperand(OS, MST);
NeedComma = true;
}
if (MDNode *MMRA = MI.getMMRAMetadata()) {
if (NeedComma)
OS << ',';
OS << " mmra ";
OS << LS << " mmra ";
MMRA->printAsOperand(OS, MST);
NeedComma = true;
}
if (uint32_t CFIType = MI.getCFIType()) {
if (NeedComma)
OS << ',';
OS << " cfi-type " << CFIType;
NeedComma = true;
}
if (uint32_t CFIType = MI.getCFIType())
OS << LS << " cfi-type " << CFIType;

if (auto Num = MI.peekDebugInstrNum()) {
if (NeedComma)
OS << ',';
OS << " debug-instr-number " << Num;
NeedComma = true;
}
if (auto Num = MI.peekDebugInstrNum())
OS << LS << " debug-instr-number " << Num;

if (PrintLocations) {
if (const DebugLoc &DL = MI.getDebugLoc()) {
if (NeedComma)
OS << ',';
OS << " debug-location ";
OS << LS << " debug-location ";
DL->printAsOperand(OS, MST);
}
}
Expand All @@ -945,12 +888,10 @@ void MIPrinter::print(const MachineInstr &MI) {
OS << " :: ";
const LLVMContext &Context = MF->getFunction().getContext();
const MachineFrameInfo &MFI = MF->getFrameInfo();
bool NeedComma = false;
LS = ListSeparator();
for (const auto *Op : MI.memoperands()) {
if (NeedComma)
OS << ", ";
OS << LS;
Op->print(OS, MST, SSNs, Context, &MFI, TII);
NeedComma = true;
}
}
}
Expand All @@ -971,10 +912,11 @@ static std::string formatOperandComment(std::string Comment) {
}

void MIPrinter::print(const MachineInstr &MI, unsigned OpIdx,
const TargetRegisterInfo *TRI,
const TargetInstrInfo *TII,
bool ShouldPrintRegisterTies, LLT TypeToPrint,
bool PrintDef) {
const TargetRegisterInfo *TRI, const TargetInstrInfo *TII,
bool ShouldPrintRegisterTies,
SmallBitVector &PrintedTypes,
const MachineRegisterInfo &MRI, bool PrintDef) {
LLT TypeToPrint = MI.getTypeToPrint(OpIdx, PrintedTypes, MRI);
const MachineOperand &Op = MI.getOperand(OpIdx);
std::string MOComment = TII->createMIROperandComment(MI, Op, OpIdx, TRI);

Expand Down
6 changes: 1 addition & 5 deletions llvm/lib/CodeGen/MIRPrintingPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,6 @@ char MIRPrintingPass::ID = 0;
char &llvm::MIRPrintingPassID = MIRPrintingPass::ID;
INITIALIZE_PASS(MIRPrintingPass, "mir-printer", "MIR Printer", false, false)

namespace llvm {

MachineFunctionPass *createPrintMIRPass(raw_ostream &OS) {
MachineFunctionPass *llvm::createPrintMIRPass(raw_ostream &OS) {
return new MIRPrintingPass(OS);
}

} // end namespace llvm
Loading