Skip to content

Commit 978de2d

Browse files
[SPIR-V] Add saturation and float rounding mode decorations, a subset of arithmetic constrained floating-point intrinsics, and SPV_INTEL_float_controls2 extension (llvm#119862)
This PR adds the following features:
* saturation and float rounding mode decorations;
* a subset of arithmetic constrained floating-point intrinsics (strict_fadd, strict_fsub, strict_fmul, strict_fdiv, strict_frem, strict_fma and strict_fldexp);
* the SPV_INTEL_float_controls2 extension;
* using recent improvements of the emit-intrinsics step, this PR also simplifies the pre- and post-legalizer steps and improves instruction selection.
1 parent ace87ec commit 978de2d

18 files changed

+395
-80
lines changed

llvm/docs/SPIRVUsage.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,8 @@ list of supported SPIR-V extensions, sorted alphabetically by their extension na
159159
- Adds instructions to convert between single-precision 32-bit floating-point values and 16-bit bfloat16 values.
160160
* - ``SPV_INTEL_cache_controls``
161161
- Allows cache control information to be applied to memory access instructions.
162+
* - ``SPV_INTEL_float_controls2``
163+
- Adds execution modes and decorations to control floating-point computations.
162164
* - ``SPV_INTEL_function_pointers``
163165
- Allows translation of function pointers.
164166
* - ``SPV_INTEL_inline_assembly``

llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,10 @@ static bool hasType(const MCInst &MI, const MCInstrInfo &MII) {
6565
// If we define an output, and have at least one other argument.
6666
if (MCDesc.getNumDefs() == 1 && MCDesc.getNumOperands() >= 2) {
6767
// Check if we define an ID, and take a type as operand 1.
68-
auto &DefOpInfo = MCDesc.operands()[0];
69-
auto &FirstArgOpInfo = MCDesc.operands()[1];
70-
return DefOpInfo.RegClass >= 0 && FirstArgOpInfo.RegClass >= 0 &&
71-
DefOpInfo.RegClass != SPIRV::TYPERegClassID &&
72-
FirstArgOpInfo.RegClass == SPIRV::TYPERegClassID;
68+
return MCDesc.operands()[0].RegClass >= 0 &&
69+
MCDesc.operands()[1].RegClass >= 0 &&
70+
MCDesc.operands()[0].RegClass != SPIRV::TYPERegClassID &&
71+
MCDesc.operands()[1].RegClass == SPIRV::TYPERegClassID;
7372
}
7473
return false;
7574
}

llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,8 @@ using namespace InstructionSet;
173173

174174
namespace SPIRV {
175175
/// Parses the name part of the demangled builtin call.
176-
std::string lookupBuiltinNameHelper(StringRef DemangledCall) {
176+
std::string lookupBuiltinNameHelper(StringRef DemangledCall,
177+
std::string *Postfix) {
177178
const static std::string PassPrefix = "(anonymous namespace)::";
178179
std::string BuiltinName;
179180
// Itanium Demangler result may have "(anonymous namespace)::" prefix
@@ -231,10 +232,13 @@ std::string lookupBuiltinNameHelper(StringRef DemangledCall) {
231232
"ReadClockKHR|SubgroupBlockReadINTEL|SubgroupImageBlockReadINTEL|"
232233
"SubgroupImageMediaBlockReadINTEL|SubgroupImageMediaBlockWriteINTEL|"
233234
"Convert|"
234-
"UConvert|SConvert|FConvert|SatConvert).*)_R.*");
235+
"UConvert|SConvert|FConvert|SatConvert).*)_R(.*)");
235236
std::smatch Match;
236-
if (std::regex_match(BuiltinName, Match, SpvWithR) && Match.size() > 2)
237+
if (std::regex_match(BuiltinName, Match, SpvWithR) && Match.size() > 3) {
237238
BuiltinName = Match[1].str();
239+
if (Postfix)
240+
*Postfix = Match[3].str();
241+
}
238242

239243
return BuiltinName;
240244
}
@@ -583,6 +587,15 @@ static Register buildScopeReg(Register CLScopeRegister,
583587
return buildConstantIntReg32(Scope, MIRBuilder, GR);
584588
}
585589

590+
/// Gives \p Reg a register class only if it does not have one yet.
///
/// The class is taken from the SPIR-V type that \p GR has recorded for the
/// virtual register; when no such type is recorded, the generic iID class is
/// used instead. A register that already has a class is left untouched.
static void setRegClassIfNull(Register Reg, MachineRegisterInfo *MRI,
                              SPIRVGlobalRegistry *GR) {
  // Never override a class that has already been assigned.
  if (MRI->getRegClassOrNull(Reg) != nullptr)
    return;

  if (SPIRVType *KnownType = GR->getSPIRVTypeForVReg(Reg)) {
    MRI->setRegClass(Reg, GR->getRegClass(KnownType));
    return;
  }

  // No SPIR-V type is known for this register: fall back to the iID class.
  MRI->setRegClass(Reg, &SPIRV::iIDRegClass);
}
598+
586599
static Register buildMemSemanticsReg(Register SemanticsRegister,
587600
Register PtrRegister, unsigned &Semantics,
588601
MachineIRBuilder &MIRBuilder,
@@ -1160,7 +1173,7 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
11601173
MIRBuilder.buildInstr(TargetOpcode::G_BUILD_VECTOR).addDef(VecReg);
11611174
for (unsigned i = 1; i < Call->Arguments.size(); i++) {
11621175
MIB.addUse(Call->Arguments[i]);
1163-
MRI->setRegClass(Call->Arguments[i], &SPIRV::iIDRegClass);
1176+
setRegClassIfNull(Call->Arguments[i], MRI, GR);
11641177
}
11651178
insertAssignInstr(VecReg, nullptr, VecType, GR, MIRBuilder,
11661179
MIRBuilder.getMF().getRegInfo());
@@ -1176,7 +1189,7 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
11761189
MIB.addImm(GroupBuiltin->GroupOperation);
11771190
if (Call->Arguments.size() > 0) {
11781191
MIB.addUse(Arg0.isValid() ? Arg0 : Call->Arguments[0]);
1179-
MRI->setRegClass(Call->Arguments[0], &SPIRV::iIDRegClass);
1192+
setRegClassIfNull(Call->Arguments[0], MRI, GR);
11801193
if (VecReg.isValid())
11811194
MIB.addUse(VecReg);
11821195
else

llvm/lib/Target/SPIRV/SPIRVBuiltins.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
namespace llvm {
2121
namespace SPIRV {
2222
/// Parses the name part of the demangled builtin call.
23-
std::string lookupBuiltinNameHelper(StringRef DemangledCall);
23+
std::string lookupBuiltinNameHelper(StringRef DemangledCall,
24+
std::string *Postfix = nullptr);
2425
/// Lowers a builtin function call using the provided \p DemangledCall skeleton
2526
/// and external instruction \p Set.
2627
///

llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
3636
SPIRV::Extension::Extension::SPV_INTEL_arbitrary_precision_integers},
3737
{"SPV_INTEL_cache_controls",
3838
SPIRV::Extension::Extension::SPV_INTEL_cache_controls},
39+
{"SPV_INTEL_float_controls2",
40+
SPIRV::Extension::Extension::SPV_INTEL_float_controls2},
3941
{"SPV_INTEL_global_variable_fpga_decorations",
4042
SPIRV::Extension::Extension::
4143
SPV_INTEL_global_variable_fpga_decorations},

llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,8 @@ class SPIRVEmitIntrinsics
216216
bool processFunctionPointers(Module &M);
217217
void parseFunDeclarations(Module &M);
218218

219+
void useRoundingMode(ConstrainedFPIntrinsic *FPI, IRBuilder<> &B);
220+
219221
public:
220222
static char ID;
221223
SPIRVEmitIntrinsics() : ModulePass(ID) {
@@ -1291,6 +1293,37 @@ void SPIRVEmitIntrinsics::preprocessCompositeConstants(IRBuilder<> &B) {
12911293
}
12921294
}
12931295

1296+
/// Emits an `spv_assign_decoration` intrinsic call that attaches the
/// decoration described by \p Node to the value produced by \p I.
///
/// \p Node is wrapped into an enclosing metadata node before being passed to
/// the intrinsic. The call is inserted at the position chosen by
/// setInsertPointAfterDef() for \p I.
static void createDecorationIntrinsic(Instruction *I, MDNode *Node,
                                      IRBuilder<> &B) {
  LLVMContext &Ctx = I->getContext();
  setInsertPointAfterDef(B, I);

  // The intrinsic takes a list of decorations, so wrap the single node.
  MDNode *DecorationList = MDNode::get(Ctx, {Node});
  auto *DecorationArg = MetadataAsValue::get(Ctx, DecorationList);
  B.CreateIntrinsic(Intrinsic::spv_assign_decoration, {I->getType()},
                    {I, DecorationArg});
}
1303+
1304+
/// Decorates the result of \p I with an FPRoundingMode decoration.
///
/// \param I                 instruction whose result is decorated.
/// \param RoundingModeDeco  numeric rounding mode operand of the decoration.
/// \param B                 builder used to emit the decoration intrinsic.
static void createRoundingModeDecoration(Instruction *I,
                                         unsigned RoundingModeDeco,
                                         IRBuilder<> &B) {
  LLVMContext &Ctx = I->getContext();
  Type *Int32Ty = Type::getInt32Ty(Ctx);

  // The decoration record is the pair {decoration id, rounding mode}, both
  // encoded as i32 constants.
  auto *DecorationId = ConstantAsMetadata::get(
      ConstantInt::get(Int32Ty, SPIRV::Decoration::FPRoundingMode));
  auto *ModeOperand =
      ConstantAsMetadata::get(ConstantInt::get(Int32Ty, RoundingModeDeco));

  createDecorationIntrinsic(I, MDNode::get(Ctx, {DecorationId, ModeOperand}),
                            B);
}
1316+
1317+
/// Decorates the result of \p I with the SaturatedConversion decoration.
///
/// The decoration record holds only the decoration id, encoded as an i32
/// constant; it has no extra operands.
static void createSaturatedConversionDecoration(Instruction *I,
                                                IRBuilder<> &B) {
  LLVMContext &Ctx = I->getContext();
  Type *Int32Ty = Type::getInt32Ty(Ctx);

  auto *DecorationId = ConstantAsMetadata::get(
      ConstantInt::get(Int32Ty, SPIRV::Decoration::SaturatedConversion));

  createDecorationIntrinsic(I, MDNode::get(Ctx, {DecorationId}), B);
}
1326+
12941327
Instruction *SPIRVEmitIntrinsics::visitCallInst(CallInst &Call) {
12951328
if (!Call.isInlineAsm())
12961329
return &Call;
@@ -1312,6 +1345,40 @@ Instruction *SPIRVEmitIntrinsics::visitCallInst(CallInst &Call) {
13121345
return &Call;
13131346
}
13141347

1348+
// Use a tip about rounding mode to create a decoration.
1349+
void SPIRVEmitIntrinsics::useRoundingMode(ConstrainedFPIntrinsic *FPI,
1350+
IRBuilder<> &B) {
1351+
std::optional<RoundingMode> RM = FPI->getRoundingMode();
1352+
if (!RM.has_value())
1353+
return;
1354+
unsigned RoundingModeDeco = std::numeric_limits<unsigned>::max();
1355+
switch (RM.value()) {
1356+
default:
1357+
// ignore unknown rounding modes
1358+
break;
1359+
case RoundingMode::NearestTiesToEven:
1360+
RoundingModeDeco = SPIRV::FPRoundingMode::FPRoundingMode::RTE;
1361+
break;
1362+
case RoundingMode::TowardNegative:
1363+
RoundingModeDeco = SPIRV::FPRoundingMode::FPRoundingMode::RTN;
1364+
break;
1365+
case RoundingMode::TowardPositive:
1366+
RoundingModeDeco = SPIRV::FPRoundingMode::FPRoundingMode::RTP;
1367+
break;
1368+
case RoundingMode::TowardZero:
1369+
RoundingModeDeco = SPIRV::FPRoundingMode::FPRoundingMode::RTZ;
1370+
break;
1371+
case RoundingMode::Dynamic:
1372+
case RoundingMode::NearestTiesToAway:
1373+
// TODO: check if supported
1374+
break;
1375+
}
1376+
if (RoundingModeDeco == std::numeric_limits<unsigned>::max())
1377+
return;
1378+
// Convert the tip about rounding mode into a decoration record.
1379+
createRoundingModeDecoration(FPI, RoundingModeDeco, B);
1380+
}
1381+
13151382
Instruction *SPIRVEmitIntrinsics::visitSwitchInst(SwitchInst &I) {
13161383
BasicBlock *ParentBB = I.getParent();
13171384
IRBuilder<> B(ParentBB);
@@ -1809,6 +1876,18 @@ bool SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I,
18091876
return true;
18101877
}
18111878

1879+
/// Maps a rounding mode suffix ("rte", "rtz", "rtp" or "rtn") to the matching
/// SPIR-V FPRoundingMode constant.
///
/// The comparison is exact (case-sensitive).
/// \returns the decoration operand, or std::numeric_limits<unsigned>::max()
/// when \p S is not one of the four recognised suffixes.
static unsigned roundingModeMDToDecorationConst(StringRef S) {
  const struct {
    const char *Suffix;
    unsigned Deco;
  } KnownModes[] = {
      {"rte", SPIRV::FPRoundingMode::FPRoundingMode::RTE},
      {"rtz", SPIRV::FPRoundingMode::FPRoundingMode::RTZ},
      {"rtp", SPIRV::FPRoundingMode::FPRoundingMode::RTP},
      {"rtn", SPIRV::FPRoundingMode::FPRoundingMode::RTN},
  };

  for (const auto &Mode : KnownModes)
    if (S == Mode.Suffix)
      return Mode.Deco;

  // Sentinel value: callers compare against it to detect "no rounding mode".
  return std::numeric_limits<unsigned>::max();
}
1890+
18121891
void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
18131892
IRBuilder<> &B) {
18141893
// TODO: extend the list of functions with known result types
@@ -1826,8 +1905,9 @@ void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
18261905
Function *CalledF = CI->getCalledFunction();
18271906
std::string DemangledName =
18281907
getOclOrSpirvBuiltinDemangledName(CalledF->getName());
1908+
std::string Postfix;
18291909
if (DemangledName.length() > 0)
1830-
DemangledName = SPIRV::lookupBuiltinNameHelper(DemangledName);
1910+
DemangledName = SPIRV::lookupBuiltinNameHelper(DemangledName, &Postfix);
18311911
auto ResIt = ResTypeWellKnown.find(DemangledName);
18321912
if (ResIt != ResTypeWellKnown.end()) {
18331913
IsKnown = true;
@@ -1839,6 +1919,19 @@ void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
18391919
break;
18401920
}
18411921
}
1922+
// check if a floating rounding mode info is present
1923+
StringRef S = Postfix;
1924+
SmallVector<StringRef, 8> Parts;
1925+
S.split(Parts, "_", -1, false);
1926+
if (Parts.size() > 1) {
1927+
// Convert the info about rounding mode into a decoration record.
1928+
unsigned RoundingModeDeco = roundingModeMDToDecorationConst(Parts[1]);
1929+
if (RoundingModeDeco != std::numeric_limits<unsigned>::max())
1930+
createRoundingModeDecoration(CI, RoundingModeDeco, B);
1931+
// Check if the SaturatedConversion info is present.
1932+
if (Parts[1] == "sat")
1933+
createSaturatedConversionDecoration(CI, B);
1934+
}
18421935
}
18431936
}
18441937

@@ -2264,6 +2357,9 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
22642357
// already, and force it to be i8 if not
22652358
if (Postpone && !GR->findAssignPtrTypeInstr(I))
22662359
insertAssignPtrTypeIntrs(I, B, true);
2360+
2361+
if (auto *FPI = dyn_cast<ConstrainedFPIntrinsic>(I))
2362+
useRoundingMode(FPI, B);
22672363
}
22682364

22692365
// Pass backward: use instructions results to specify/update/cast operands

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -126,14 +126,14 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
126126
Width = adjustOpTypeIntWidth(Width);
127127
const SPIRVSubtarget &ST =
128128
cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
129-
if (ST.canUseExtension(
130-
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
131-
MIRBuilder.buildInstr(SPIRV::OpExtension)
132-
.addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
133-
MIRBuilder.buildInstr(SPIRV::OpCapability)
134-
.addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
135-
}
136129
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
130+
if (ST.canUseExtension(
131+
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
132+
MIRBuilder.buildInstr(SPIRV::OpExtension)
133+
.addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
134+
MIRBuilder.buildInstr(SPIRV::OpCapability)
135+
.addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
136+
}
137137
return MIRBuilder.buildInstr(SPIRV::OpTypeInt)
138138
.addDef(createTypeVReg(MIRBuilder))
139139
.addImm(Width)

llvm/lib/Target/SPIRV/SPIRVInstrInfo.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,23 +491,29 @@ def OpFNegate: UnOpTyped<"OpFNegate", 127, fID, fneg>;
491491
def OpFNegateV: UnOpTyped<"OpFNegate", 127, vfID, fneg>;
492492
defm OpIAdd: BinOpTypedGen<"OpIAdd", 128, add, 0, 1>;
493493
defm OpFAdd: BinOpTypedGen<"OpFAdd", 129, fadd, 1, 1>;
494+
defm OpStrictFAdd: BinOpTypedGen<"OpFAdd", 129, strict_fadd, 1, 1>;
494495

495496
defm OpISub: BinOpTypedGen<"OpISub", 130, sub, 0, 1>;
496497
defm OpFSub: BinOpTypedGen<"OpFSub", 131, fsub, 1, 1>;
498+
defm OpStrictFSub: BinOpTypedGen<"OpFSub", 131, strict_fsub, 1, 1>;
497499

498500
defm OpIMul: BinOpTypedGen<"OpIMul", 132, mul, 0, 1>;
499501
defm OpFMul: BinOpTypedGen<"OpFMul", 133, fmul, 1, 1>;
502+
defm OpStrictFMul: BinOpTypedGen<"OpFMul", 133, strict_fmul, 1, 1>;
500503

501504
defm OpUDiv: BinOpTypedGen<"OpUDiv", 134, udiv, 0, 1>;
502505
defm OpSDiv: BinOpTypedGen<"OpSDiv", 135, sdiv, 0, 1>;
503506
defm OpFDiv: BinOpTypedGen<"OpFDiv", 136, fdiv, 1, 1>;
507+
defm OpStrictFDiv: BinOpTypedGen<"OpFDiv", 136, strict_fdiv, 1, 1>;
504508

505509
defm OpUMod: BinOpTypedGen<"OpUMod", 137, urem, 0, 1>;
506510
defm OpSRem: BinOpTypedGen<"OpSRem", 138, srem, 0, 1>;
507511

508512
def OpSMod: BinOp<"OpSMod", 139>;
509513

510514
defm OpFRem: BinOpTypedGen<"OpFRem", 140, frem, 1, 1>;
515+
defm OpStrictFRem: BinOpTypedGen<"OpFRem", 140, strict_frem, 1, 1>;
516+
511517
def OpFMod: BinOp<"OpFMod", 141>;
512518

513519
def OpVectorTimesScalar: BinOp<"OpVectorTimesScalar", 142>;

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class SPIRVInstructionSelector : public InstructionSelector {
6161
/// We need to keep track of the number we give to anonymous global values to
6262
/// generate the same name every time when this is needed.
6363
mutable DenseMap<const GlobalValue *, unsigned> UnnamedGlobalIDs;
64+
SmallPtrSet<MachineInstr *, 8> DeadMIs;
6465

6566
public:
6667
SPIRVInstructionSelector(const SPIRVTargetMachine &TM,
@@ -382,6 +383,24 @@ static bool isImm(const MachineOperand &MO, MachineRegisterInfo *MRI);
382383
// Defined in SPIRVLegalizerInfo.cpp.
383384
extern bool isTypeFoldingSupported(unsigned Opcode);
384385

386+
bool isDead(const MachineInstr &MI, const MachineRegisterInfo &MRI) {
387+
for (const auto &MO : MI.all_defs()) {
388+
Register Reg = MO.getReg();
389+
if (Reg.isPhysical() || !MRI.use_nodbg_empty(Reg))
390+
return false;
391+
}
392+
if (MI.getOpcode() == TargetOpcode::LOCAL_ESCAPE || MI.isFakeUse() ||
393+
MI.isLifetimeMarker())
394+
return false;
395+
if (MI.isPHI())
396+
return true;
397+
if (MI.mayStore() || MI.isCall() ||
398+
(MI.mayLoad() && MI.hasOrderedMemoryRef()) || MI.isPosition() ||
399+
MI.isDebugInstr() || MI.isTerminator() || MI.isJumpTableDebugInfo())
400+
return false;
401+
return true;
402+
}
403+
385404
bool SPIRVInstructionSelector::select(MachineInstr &I) {
386405
resetVRegsType(*I.getParent()->getParent());
387406

@@ -404,8 +423,11 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
404423
}
405424
});
406425
assert(Res || Def->getOpcode() == TargetOpcode::G_CONSTANT);
407-
if (Res)
426+
if (Res) {
427+
if (!isTriviallyDead(*Def, *MRI) && isDead(*Def, *MRI))
428+
DeadMIs.insert(Def);
408429
return Res;
430+
}
409431
}
410432
MRI->setRegClass(SrcReg, MRI->getRegClass(DstReg));
411433
MRI->replaceRegWith(SrcReg, DstReg);
@@ -418,6 +440,15 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
418440
return constrainSelectedInstRegOperands(I, TII, TRI, RBI);
419441
}
420442

443+
if (DeadMIs.contains(&I)) {
444+
// if the instruction has been already made dead by folding it away
445+
// erase it
446+
LLVM_DEBUG(dbgs() << "Instruction is folded and dead.\n");
447+
salvageDebugInfo(*MRI, I);
448+
I.eraseFromParent();
449+
return true;
450+
}
451+
421452
if (I.getNumOperands() != I.getNumExplicitOperands()) {
422453
LLVM_DEBUG(errs() << "Generic instr has unexpected implicit operands\n");
423454
return false;
@@ -557,9 +588,13 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
557588
case TargetOpcode::G_UCMP:
558589
return selectSUCmp(ResVReg, ResType, I, false);
559590

591+
case TargetOpcode::G_STRICT_FMA:
560592
case TargetOpcode::G_FMA:
561593
return selectExtInst(ResVReg, ResType, I, CL::fma, GL::Fma);
562594

595+
case TargetOpcode::G_STRICT_FLDEXP:
596+
return selectExtInst(ResVReg, ResType, I, CL::ldexp);
597+
563598
case TargetOpcode::G_FPOW:
564599
return selectExtInst(ResVReg, ResType, I, CL::pow, GL::Pow);
565600
case TargetOpcode::G_FPOWI:
@@ -618,6 +653,7 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
618653
case TargetOpcode::G_FTANH:
619654
return selectExtInst(ResVReg, ResType, I, CL::tanh, GL::Tanh);
620655

656+
case TargetOpcode::G_STRICT_FSQRT:
621657
case TargetOpcode::G_FSQRT:
622658
return selectExtInst(ResVReg, ResType, I, CL::sqrt, GL::Sqrt);
623659

0 commit comments

Comments
 (0)