Skip to content

Commit 7f79653

Browse files
improve type deduction for phi and call base
1 parent 91b9add commit 7f79653

File tree

3 files changed

+110
-67
lines changed

3 files changed

+110
-67
lines changed

llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,11 @@ class SPIRVAsmPrinter : public AsmPrinter {
7878
void outputExecutionMode(const Module &M);
7979
void outputAnnotations(const Module &M);
8080
void outputModuleSections();
81+
bool isHidden() {
82+
return MF->getFunction()
83+
.getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
84+
.isValid();
85+
}
8186

8287
void emitInstruction(const MachineInstr *MI) override;
8388
void emitFunctionEntryLabel() override {}
@@ -131,7 +136,7 @@ void SPIRVAsmPrinter::emitFunctionHeader() {
131136
TII = ST->getInstrInfo();
132137
const Function &F = MF->getFunction();
133138

134-
if (isVerbose()) {
139+
if (isVerbose() && !isHidden()) {
135140
OutStreamer->getCommentOS()
136141
<< "-- Begin function "
137142
<< GlobalValue::dropLLVMManglingEscape(F.getName()) << '\n';
@@ -150,16 +155,17 @@ void SPIRVAsmPrinter::outputOpFunctionEnd() {
150155
// Emit OpFunctionEnd at the end of MF and clear BBNumToRegMap.
151156
void SPIRVAsmPrinter::emitFunctionBodyEnd() {
152157
// Do not emit anything if it's an internal service function.
153-
if (MF->getFunction()
154-
.getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
155-
.isValid())
158+
if (isHidden())
156159
return;
157-
158160
outputOpFunctionEnd();
159161
MAI->BBNumToRegMap.clear();
160162
}
161163

162164
void SPIRVAsmPrinter::emitOpLabel(const MachineBasicBlock &MBB) {
165+
// Do not emit anything if it's an internal service function.
166+
if (isHidden())
167+
return;
168+
163169
MCInst LabelInst;
164170
LabelInst.setOpcode(SPIRV::OpLabel);
165171
LabelInst.addOperand(MCOperand::createReg(MAI->getOrCreateMBBRegister(MBB)));

llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -500,16 +500,6 @@ void SPIRVCallLowering::produceIndirectPtrTypes(
500500

501501
bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
502502
CallLoweringInfo &Info) const {
503-
// Ignore if called from the internal service function
504-
if (MIRBuilder.getMF()
505-
.getFunction()
506-
.getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
507-
.isValid()) {
508-
// insert a no-op
509-
MIRBuilder.buildTrap();
510-
return true;
511-
}
512-
513503
// Currently call returns should have single vregs.
514504
// TODO: handle the case of multiple registers.
515505
if (Info.OrigRet.Regs.size() > 1)
@@ -597,6 +587,16 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
597587
lowerFormalArguments(FirstBlockBuilder, *CF, VRegArgs, FuncInfo);
598588
}
599589

590+
// Ignore the call if it's called from the internal service function
591+
if (MIRBuilder.getMF()
592+
.getFunction()
593+
.getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
594+
.isValid()) {
595+
// insert a no-op
596+
MIRBuilder.buildTrap();
597+
return true;
598+
}
599+
600600
unsigned CallOp;
601601
if (Info.CB->isIndirectCall()) {
602602
if (!ST->canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers))

llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp

Lines changed: 89 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,8 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeByValueDeep(
386386
// Traverse User instructions to deduce an element pointer type of the operand.
387387
Type *SPIRVEmitIntrinsics::deduceElementTypeByUsersDeep(
388388
Value *Op, std::unordered_set<Value *> &Visited, bool UnknownElemTypeI8) {
389-
if (!Op || !isPointerTy(Op->getType()))
389+
if (!Op || !isPointerTy(Op->getType()) || isa<ConstantPointerNull>(Op) ||
390+
isa<UndefValue>(Op))
390391
return nullptr;
391392

392393
if (auto ElemTy = getPointeeType(Op->getType()))
@@ -483,12 +484,25 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
483484
if (isPointerTy(Op->getType()))
484485
Ty = deduceElementTypeHelper(Op, Visited, UnknownElemTypeI8);
485486
} else if (auto *Ref = dyn_cast<PHINode>(I)) {
486-
for (unsigned i = 0; i < Ref->getNumIncomingValues(); i++) {
487+
Type *BestTy = nullptr;
488+
unsigned MaxN = 1;
489+
DenseMap<Type *, unsigned> PhiTys;
490+
for (int i = Ref->getNumIncomingValues() - 1; i >= 0; --i) {
487491
Ty = deduceElementTypeByUsersDeep(Ref->getIncomingValue(i), Visited,
488492
UnknownElemTypeI8);
489-
if (Ty)
490-
break;
493+
if (!Ty)
494+
continue;
495+
auto It = PhiTys.try_emplace(Ty, 1);
496+
if (!It.second) {
497+
++It.first->second;
498+
if (It.first->second > MaxN) {
499+
MaxN = It.first->second;
500+
BestTy = Ty;
501+
}
502+
}
491503
}
504+
if (BestTy)
505+
Ty = BestTy;
492506
} else if (auto *Ref = dyn_cast<SelectInst>(I)) {
493507
for (Value *Op : {Ref->getTrueValue(), Ref->getFalseValue()}) {
494508
Ty = deduceElementTypeByUsersDeep(Op, Visited, UnknownElemTypeI8);
@@ -644,6 +658,62 @@ static inline Type *getAtomicElemTy(SPIRVGlobalRegistry *GR, Instruction *I,
644658
return nullptr;
645659
}
646660

661+
// Try to deduce element type for a call base. Returns false if this is an
662+
// indirect function invocation, and true otherwise.
663+
static bool deduceOperandElementTypeCalledFunction(
664+
SPIRVGlobalRegistry *GR, Instruction *I,
665+
SPIRV::InstructionSet::InstructionSet InstrSet, CallInst *CI,
666+
SmallVector<std::pair<Value *, unsigned>> &Ops, Type *&KnownElemTy) {
667+
Function *CalledF = CI->getCalledFunction();
668+
if (!CalledF)
669+
return false;
670+
std::string DemangledName =
671+
getOclOrSpirvBuiltinDemangledName(CalledF->getName());
672+
if (DemangledName.length() > 0 &&
673+
!StringRef(DemangledName).starts_with("llvm.")) {
674+
auto [Grp, Opcode, ExtNo] =
675+
SPIRV::mapBuiltinToOpcode(DemangledName, InstrSet);
676+
if (Opcode == SPIRV::OpGroupAsyncCopy) {
677+
for (unsigned i = 0, PtrCnt = 0; i < CI->arg_size() && PtrCnt < 2; ++i) {
678+
Value *Op = CI->getArgOperand(i);
679+
if (!isPointerTy(Op->getType()))
680+
continue;
681+
++PtrCnt;
682+
if (Type *ElemTy = GR->findDeducedElementType(Op))
683+
KnownElemTy = ElemTy; // src will rewrite dest if both are defined
684+
Ops.push_back(std::make_pair(Op, i));
685+
}
686+
} else if (Grp == SPIRV::Atomic || Grp == SPIRV::AtomicFloating) {
687+
if (CI->arg_size() < 2)
688+
return true;
689+
Value *Op = CI->getArgOperand(0);
690+
if (!isPointerTy(Op->getType()))
691+
return true;
692+
switch (Opcode) {
693+
case SPIRV::OpAtomicLoad:
694+
case SPIRV::OpAtomicCompareExchangeWeak:
695+
case SPIRV::OpAtomicCompareExchange:
696+
case SPIRV::OpAtomicExchange:
697+
case SPIRV::OpAtomicIAdd:
698+
case SPIRV::OpAtomicISub:
699+
case SPIRV::OpAtomicOr:
700+
case SPIRV::OpAtomicXor:
701+
case SPIRV::OpAtomicAnd:
702+
case SPIRV::OpAtomicUMin:
703+
case SPIRV::OpAtomicUMax:
704+
case SPIRV::OpAtomicSMin:
705+
case SPIRV::OpAtomicSMax: {
706+
KnownElemTy = getAtomicElemTy(GR, I, Op);
707+
if (!KnownElemTy)
708+
return true;
709+
Ops.push_back(std::make_pair(Op, 0));
710+
} break;
711+
}
712+
}
713+
}
714+
return true;
715+
}
716+
647717
// If the Instruction has Pointer operands with unresolved types, this function
648718
// tries to deduce them. If the Instruction has Pointer operands with known
649719
// types which differ from expected, this function tries to insert a bitcast to
@@ -749,53 +819,17 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
749819
KnownElemTy = ElemTy1;
750820
Ops.push_back(std::make_pair(Op0, 0));
751821
}
752-
} else if (auto *CI = dyn_cast<CallInst>(I)) {
753-
if (Function *CalledF = CI->getCalledFunction()) {
754-
std::string DemangledName =
755-
getOclOrSpirvBuiltinDemangledName(CalledF->getName());
756-
if (DemangledName.length() > 0 &&
757-
!StringRef(DemangledName).starts_with("llvm.")) {
758-
auto [Grp, Opcode, ExtNo] =
759-
SPIRV::mapBuiltinToOpcode(DemangledName, InstrSet);
760-
if (Opcode == SPIRV::OpGroupAsyncCopy) {
761-
for (unsigned i = 0, PtrCnt = 0; i < CI->arg_size() && PtrCnt < 2;
762-
++i) {
763-
Value *Op = CI->getArgOperand(i);
764-
if (!isPointerTy(Op->getType()))
765-
continue;
766-
++PtrCnt;
767-
if (Type *ElemTy = GR->findDeducedElementType(Op))
768-
KnownElemTy = ElemTy; // src will rewrite dest if both are defined
769-
Ops.push_back(std::make_pair(Op, i));
770-
}
771-
} else if (Grp == SPIRV::Atomic || Grp == SPIRV::AtomicFloating) {
772-
if (CI->arg_size() < 2)
773-
return;
774-
Value *Op = CI->getArgOperand(0);
775-
if (!isPointerTy(Op->getType()))
776-
return;
777-
switch (Opcode) {
778-
case SPIRV::OpAtomicLoad:
779-
case SPIRV::OpAtomicCompareExchangeWeak:
780-
case SPIRV::OpAtomicCompareExchange:
781-
case SPIRV::OpAtomicExchange:
782-
case SPIRV::OpAtomicIAdd:
783-
case SPIRV::OpAtomicISub:
784-
case SPIRV::OpAtomicOr:
785-
case SPIRV::OpAtomicXor:
786-
case SPIRV::OpAtomicAnd:
787-
case SPIRV::OpAtomicUMin:
788-
case SPIRV::OpAtomicUMax:
789-
case SPIRV::OpAtomicSMin:
790-
case SPIRV::OpAtomicSMax: {
791-
KnownElemTy = getAtomicElemTy(GR, I, Op);
792-
if (!KnownElemTy)
793-
return;
794-
Ops.push_back(std::make_pair(Op, 0));
795-
} break;
796-
}
797-
}
798-
}
822+
} else if (CallInst *CI = dyn_cast<CallInst>(I)) {
823+
if (!CI->isIndirectCall()) {
824+
deduceOperandElementTypeCalledFunction(GR, I, InstrSet, CI, Ops,
825+
KnownElemTy);
826+
} else if (TM->getSubtarget<SPIRVSubtarget>(*F).canUseExtension(
827+
SPIRV::Extension::SPV_INTEL_function_pointers)) {
828+
Value *Op = CI->getCalledOperand();
829+
if (!Op || !isPointerTy(Op->getType()))
830+
return;
831+
Ops.push_back(std::make_pair(Op, std::numeric_limits<unsigned>::max()));
832+
KnownElemTy = CI->getFunctionType();
799833
}
800834
}
801835

@@ -846,7 +880,10 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
846880
B.getInt32(getPointerAddressSpace(OpTy))};
847881
CallInst *PtrCastI =
848882
B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
849-
I->setOperand(OpIt.second, PtrCastI);
883+
if (OpIt.second == std::numeric_limits<unsigned>::max())
884+
dyn_cast<CallInst>(I)->setCalledOperand(PtrCastI);
885+
else
886+
I->setOperand(OpIt.second, PtrCastI);
850887
buildAssignPtr(B, KnownElemTy, PtrCastI);
851888
}
852889
}

0 commit comments

Comments
 (0)