Skip to content

[GlobalISel] Improve Handling of Immediates in Apply MIR Patterns #66071

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions llvm/docs/GlobalISel/MIRPatterns.rst
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,8 @@ Common Pattern #3: Emitting a Constant Value
When an immediate operand appears in an 'apply' pattern, the behavior
depends on whether it's typed or not.

* If the immediate is typed, a ``G_CONSTANT`` is implicitly emitted
(= a register operand is added to the instruction).
* If the immediate is typed, ``MachineIRBuilder::buildConstant`` is used
to create a ``G_CONSTANT``. A ``G_BUILD_VECTOR`` will be used for vectors.
* If the immediate is untyped, a simple immediate is added
(``MachineInstrBuilder::addImm``).

Expand Down
60 changes: 43 additions & 17 deletions llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutor.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "llvm/ADT/Bitset.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/LowLevelType.h"
#include "llvm/CodeGen/MachineFunction.h"
Expand All @@ -40,6 +41,7 @@ class APInt;
class APFloat;
class GISelKnownBits;
class MachineInstr;
class MachineIRBuilder;
class MachineInstrBuilder;
class MachineFunction;
class MachineOperand;
Expand Down Expand Up @@ -274,6 +276,12 @@ enum {
/// - StoreIdx - Store location in RecordedOperands.
GIM_RecordNamedOperand,

/// TODO DESC
/// - InsnID - Instruction ID
/// - OpIdx - Operand index
/// - TempTypeIdx - Temp Type Index, always negative.
GIM_RecordRegType,

/// Fail the current try-block, or completely fail to match if there is no
/// current try-block.
GIM_Reject,
Expand All @@ -291,6 +299,11 @@ enum {
/// - Opcode - The new opcode to use
GIR_BuildMI,

/// Builds a constant and stores its result in a TempReg.
/// - TempRegID - Temp Register to define.
/// - Imm - The immediate to add
GIR_BuildConstant,

/// Copy an operand to the specified instruction
/// - NewInsnID - Instruction ID to modify
/// - OldInsnID - Instruction ID to copy from
Expand Down Expand Up @@ -350,12 +363,6 @@ enum {
/// - Imm - The immediate to add
GIR_AddImm,

/// Add an CImm to the specified instruction
/// - InsnID - Instruction ID to modify
/// - Ty - Type of the constant immediate.
/// - Imm - The immediate to add
GIR_AddCImm,

/// Render complex operands to the specified instruction
/// - InsnID - Instruction ID to modify
/// - RendererID - The renderer to call
Expand Down Expand Up @@ -501,10 +508,25 @@ class GIMatchTableExecutor {
}

protected:
/// Observer used by \ref executeMatchTable to record all instructions created
/// by the rule.
class GIMatchTableObserver : public GISelChangeObserver {
public:
virtual ~GIMatchTableObserver();

void erasingInstr(MachineInstr &MI) override { CreatedInsts.erase(&MI); }
void createdInstr(MachineInstr &MI) override { CreatedInsts.insert(&MI); }
void changingInstr(MachineInstr &MI) override {}
void changedInstr(MachineInstr &MI) override {}

// Keeps track of all instructions that have been created when applying a
// rule.
SmallDenseSet<MachineInstr *, 4> CreatedInsts;
};

using ComplexRendererFns =
std::optional<SmallVector<std::function<void(MachineInstrBuilder &)>, 4>>;
using RecordedMIVector = SmallVector<MachineInstr *, 4>;
using NewMIVector = SmallVector<MachineInstrBuilder, 4>;

struct MatcherState {
std::vector<ComplexRendererFns::value_type> Renderers;
Expand All @@ -516,6 +538,10 @@ class GIMatchTableExecutor {
/// list. Currently such predicates don't have more then 3 arguments.
std::array<const MachineOperand *, 3> RecordedOperands;

/// Types extracted from an instruction's operand.
/// Whenever a type index is negative, we look here instead.
SmallVector<LLT, 4> RecordedTypes;

MatcherState(unsigned MaxRenderers);
};

Expand Down Expand Up @@ -555,15 +581,15 @@ class GIMatchTableExecutor {
/// and false otherwise.
template <class TgtExecutor, class PredicateBitset, class ComplexMatcherMemFn,
class CustomRendererFn>
bool executeMatchTable(
TgtExecutor &Exec, NewMIVector &OutMIs, MatcherState &State,
const ExecInfoTy<PredicateBitset, ComplexMatcherMemFn, CustomRendererFn>
&ISelInfo,
const int64_t *MatchTable, const TargetInstrInfo &TII,
MachineRegisterInfo &MRI, const TargetRegisterInfo &TRI,
const RegisterBankInfo &RBI, const PredicateBitset &AvailableFeatures,
CodeGenCoverage *CoverageInfo,
GISelChangeObserver *Observer = nullptr) const;
bool executeMatchTable(TgtExecutor &Exec, MatcherState &State,
const ExecInfoTy<PredicateBitset, ComplexMatcherMemFn,
CustomRendererFn> &ExecInfo,
MachineIRBuilder &Builder, const int64_t *MatchTable,
const TargetInstrInfo &TII, MachineRegisterInfo &MRI,
const TargetRegisterInfo &TRI,
const RegisterBankInfo &RBI,
const PredicateBitset &AvailableFeatures,
CodeGenCoverage *CoverageInfo) const;

virtual const int64_t *getMatchTable() const {
llvm_unreachable("Should have been overridden by tablegen if used");
Expand Down Expand Up @@ -592,7 +618,7 @@ class GIMatchTableExecutor {
}

virtual void runCustomAction(unsigned, const MatcherState &State,
NewMIVector &OutMIs) const {
ArrayRef<MachineInstrBuilder> OutMIs) const {
llvm_unreachable("Subclass does not implement runCustomAction!");
}

Expand Down
125 changes: 79 additions & 46 deletions llvm/include/llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/CodeGen/GlobalISel/GIMatchTableExecutor.h"
#include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineOperand.h"
Expand All @@ -42,17 +43,33 @@ namespace llvm {
template <class TgtExecutor, class PredicateBitset, class ComplexMatcherMemFn,
class CustomRendererFn>
bool GIMatchTableExecutor::executeMatchTable(
TgtExecutor &Exec, NewMIVector &OutMIs, MatcherState &State,
TgtExecutor &Exec, MatcherState &State,
const ExecInfoTy<PredicateBitset, ComplexMatcherMemFn, CustomRendererFn>
&ExecInfo,
const int64_t *MatchTable, const TargetInstrInfo &TII,
MachineRegisterInfo &MRI, const TargetRegisterInfo &TRI,
const RegisterBankInfo &RBI, const PredicateBitset &AvailableFeatures,
CodeGenCoverage *CoverageInfo, GISelChangeObserver *Observer) const {
MachineIRBuilder &Builder, const int64_t *MatchTable,
const TargetInstrInfo &TII, MachineRegisterInfo &MRI,
const TargetRegisterInfo &TRI, const RegisterBankInfo &RBI,
const PredicateBitset &AvailableFeatures,
CodeGenCoverage *CoverageInfo) const {

// Setup observer
GIMatchTableObserver MTObserver;
GISelObserverWrapper Observer(&MTObserver);
if (auto *CurObs = Builder.getChangeObserver())
Observer.addObserver(CurObs);

// TODO: Set MF delegate?

// Setup builder.
auto RestoreOldObserver = Builder.setTemporaryChangeObserver(Observer);

uint64_t CurrentIdx = 0;
SmallVector<uint64_t, 4> OnFailResumeAt;

// We also record MachineInstrs manually in this vector so opcodes can address
// them.
SmallVector<MachineInstrBuilder, 4> OutMIs;

// Bypass the flag check on the instruction, and only look at the MCInstrDesc.
bool NoFPException = !State.MIs[0]->getDesc().mayRaiseFPException();

Expand All @@ -71,19 +88,29 @@ bool GIMatchTableExecutor::executeMatchTable(
return RejectAndResume;
};

auto propagateFlags = [=](NewMIVector &OutMIs) {
for (auto MIB : OutMIs) {
auto propagateFlags = [&]() {
for (auto *MI : MTObserver.CreatedInsts) {
// Set the NoFPExcept flag when no original matched instruction could
// raise an FP exception, but the new instruction potentially might.
uint16_t MIBFlags = Flags;
if (NoFPException && MIB->mayRaiseFPException())
if (NoFPException && MI->mayRaiseFPException())
MIBFlags |= MachineInstr::NoFPExcept;
MIB.setMIFlags(MIBFlags);
Observer.changingInstr(*MI);
MI->setFlags(MIBFlags);
Observer.changedInstr(*MI);
}

return true;
};

// If the index is >= 0, it's an index in the type objects generated by TableGen.
// If the index is <0, it's an index in the recorded types object.
auto getTypeFromIdx = [&](int64_t Idx) -> const LLT& {
if(Idx >= 0)
return ExecInfo.TypeObjects[Idx];
return State.RecordedTypes[1 - Idx];
};

while (true) {
assert(CurrentIdx != ~0u && "Invalid MatchTable index");
int64_t MatcherOpcode = MatchTable[CurrentIdx++];
Expand Down Expand Up @@ -620,7 +647,7 @@ bool GIMatchTableExecutor::executeMatchTable(
assert(State.MIs[InsnID] != nullptr && "Used insn before defined");
MachineOperand &MO = State.MIs[InsnID]->getOperand(OpIdx);
if (!MO.isReg() ||
MRI.getType(MO.getReg()) != ExecInfo.TypeObjects[TypeID]) {
MRI.getType(MO.getReg()) != getTypeFromIdx(TypeID)) {
if (handleReject() == RejectAndGiveUp)
return false;
}
Expand Down Expand Up @@ -671,6 +698,25 @@ bool GIMatchTableExecutor::executeMatchTable(
State.RecordedOperands[StoreIdx] = &State.MIs[InsnID]->getOperand(OpIdx);
break;
}
case GIM_RecordRegType: {
int64_t InsnID = MatchTable[CurrentIdx++];
int64_t OpIdx = MatchTable[CurrentIdx++];
int64_t TypeIdx = MatchTable[CurrentIdx++];

DEBUG_WITH_TYPE(TgtExecutor::getName(),
dbgs() << CurrentIdx << ": GIM_RecordRegType(MIs["
<< InsnID << "]->getOperand(" << OpIdx
<< "), TypeIdx=" << TypeIdx << ")\n");
assert(State.MIs[InsnID] != nullptr && "Used insn before defined");
assert(TypeIdx <= 0 && "Temp types always have negative indexes!");
// Indexes start at -1.
TypeIdx = 1 - TypeIdx;
const auto& Op = State.MIs[InsnID]->getOperand(OpIdx);
if(State.RecordedTypes.size() <= (uint64_t)TypeIdx)
State.RecordedTypes.resize(TypeIdx + 1, LLT());
State.RecordedTypes[TypeIdx] = MRI.getType(Op.getReg());
break;
}
case GIM_CheckRegBankForClass: {
int64_t InsnID = MatchTable[CurrentIdx++];
int64_t OpIdx = MatchTable[CurrentIdx++];
Expand Down Expand Up @@ -901,6 +947,7 @@ bool GIMatchTableExecutor::executeMatchTable(
OutMIs[NewInsnID] = MachineInstrBuilder(*State.MIs[OldInsnID]->getMF(),
State.MIs[OldInsnID]);
OutMIs[NewInsnID]->setDesc(TII.get(NewOpcode));
MTObserver.CreatedInsts.insert(OutMIs[NewInsnID]);
DEBUG_WITH_TYPE(TgtExecutor::getName(),
dbgs() << CurrentIdx << ": GIR_MutateOpcode(OutMIs["
<< NewInsnID << "], MIs[" << OldInsnID << "], "
Expand All @@ -914,14 +961,23 @@ bool GIMatchTableExecutor::executeMatchTable(
if (NewInsnID >= OutMIs.size())
OutMIs.resize(NewInsnID + 1);

OutMIs[NewInsnID] = BuildMI(*State.MIs[0]->getParent(), State.MIs[0],
MIMetadata(*State.MIs[0]), TII.get(Opcode));
OutMIs[NewInsnID] = Builder.buildInstr(Opcode);
DEBUG_WITH_TYPE(TgtExecutor::getName(),
dbgs() << CurrentIdx << ": GIR_BuildMI(OutMIs["
<< NewInsnID << "], " << Opcode << ")\n");
break;
}

case GIR_BuildConstant: {
int64_t TempRegID = MatchTable[CurrentIdx++];
int64_t Imm = MatchTable[CurrentIdx++];
Builder.buildConstant(State.TempRegisters[TempRegID], Imm);
DEBUG_WITH_TYPE(TgtExecutor::getName(),
dbgs() << CurrentIdx << ": GIR_BuildConstant(TempReg["
<< TempRegID << "], Imm=" << Imm << ")\n");
break;
}

case GIR_Copy: {
int64_t NewInsnID = MatchTable[CurrentIdx++];
int64_t OldInsnID = MatchTable[CurrentIdx++];
Expand Down Expand Up @@ -1047,24 +1103,6 @@ bool GIMatchTableExecutor::executeMatchTable(
<< "], " << Imm << ")\n");
break;
}

case GIR_AddCImm: {
int64_t InsnID = MatchTable[CurrentIdx++];
int64_t TypeID = MatchTable[CurrentIdx++];
int64_t Imm = MatchTable[CurrentIdx++];
assert(OutMIs[InsnID] && "Attempted to add to undefined instruction");

unsigned Width = ExecInfo.TypeObjects[TypeID].getScalarSizeInBits();
LLVMContext &Ctx = MF->getFunction().getContext();
OutMIs[InsnID].addCImm(
ConstantInt::get(IntegerType::get(Ctx, Width), Imm, /*signed*/ true));
DEBUG_WITH_TYPE(TgtExecutor::getName(),
dbgs() << CurrentIdx << ": GIR_AddCImm(OutMIs[" << InsnID
<< "], TypeID=" << TypeID << ", Imm=" << Imm
<< ")\n");
break;
}

case GIR_ComplexRenderer: {
int64_t InsnID = MatchTable[CurrentIdx++];
int64_t RendererID = MatchTable[CurrentIdx++];
Expand Down Expand Up @@ -1239,8 +1277,11 @@ bool GIMatchTableExecutor::executeMatchTable(
DEBUG_WITH_TYPE(TgtExecutor::getName(),
dbgs() << CurrentIdx << ": GIR_EraseFromParent(MIs["
<< InsnID << "])\n");
if (Observer)
Observer->erasingInstr(*MI);
// If we're erasing the insertion point, ensure we don't leave a dangling
// pointer in the builder.
if (Builder.getInsertPt() == MI)
Builder.setInsertPt(*MI->getParent(), ++MI->getIterator());
Observer.erasingInstr(*MI);
MI->eraseFromParent();
break;
}
Expand All @@ -1250,7 +1291,7 @@ bool GIMatchTableExecutor::executeMatchTable(
int64_t TypeID = MatchTable[CurrentIdx++];

State.TempRegisters[TempRegID] =
MRI.createGenericVirtualRegister(ExecInfo.TypeObjects[TypeID]);
MRI.createGenericVirtualRegister(getTypeFromIdx(TypeID));
DEBUG_WITH_TYPE(TgtExecutor::getName(),
dbgs() << CurrentIdx << ": TempRegs[" << TempRegID
<< "] = GIR_MakeTempReg(" << TypeID << ")\n");
Expand All @@ -1269,11 +1310,9 @@ bool GIMatchTableExecutor::executeMatchTable(

Register Old = State.MIs[OldInsnID]->getOperand(OldOpIdx).getReg();
Register New = State.MIs[NewInsnID]->getOperand(NewOpIdx).getReg();
if (Observer)
Observer->changingAllUsesOfReg(MRI, Old);
Observer.changingAllUsesOfReg(MRI, Old);
MRI.replaceRegWith(Old, New);
if (Observer)
Observer->finishedChangingAllUsesOfReg();
Observer.finishedChangingAllUsesOfReg();
break;
}
case GIR_ReplaceRegWithTempReg: {
Expand All @@ -1288,11 +1327,9 @@ bool GIMatchTableExecutor::executeMatchTable(

Register Old = State.MIs[OldInsnID]->getOperand(OldOpIdx).getReg();
Register New = State.TempRegisters[TempRegID];
if (Observer)
Observer->changingAllUsesOfReg(MRI, Old);
Observer.changingAllUsesOfReg(MRI, Old);
MRI.replaceRegWith(Old, New);
if (Observer)
Observer->finishedChangingAllUsesOfReg();
Observer.finishedChangingAllUsesOfReg();
break;
}
case GIR_Coverage: {
Expand All @@ -1309,11 +1346,7 @@ bool GIMatchTableExecutor::executeMatchTable(
case GIR_Done:
DEBUG_WITH_TYPE(TgtExecutor::getName(),
dbgs() << CurrentIdx << ": GIR_Done\n");
if (Observer) {
for (MachineInstr *MI : OutMIs)
Observer->createdInstr(*MI);
}
propagateFlags(OutMIs);
propagateFlags();
return true;
default:
llvm_unreachable("Unexpected command");
Expand Down
10 changes: 10 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/IR/DebugLoc.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/SaveAndRestore.h"

namespace llvm {

Expand Down Expand Up @@ -364,6 +365,15 @@ class MachineIRBuilder {
State.Observer = &Observer;
}

GISelChangeObserver *getChangeObserver() const { return State.Observer; }

// Replaces the change observer with \p Observer and returns an object that
// restores the old Observer on destruction.
SaveAndRestore<GISelChangeObserver *>
setTemporaryChangeObserver(GISelChangeObserver &Observer) {
return SaveAndRestore<GISelChangeObserver *>(State.Observer, &Observer);
}

void stopObservingChanges() { State.Observer = nullptr; }

bool isObservingChanges() const { return State.Observer != nullptr; }
Expand Down
Loading