Skip to content

[SPIR-V] Add saturation and float rounding mode decorations, a subset of arithmetic constrained floating-point intrinsics, and SPV_INTEL_float_controls2 extension #119862

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
2 changes: 2 additions & 0 deletions llvm/docs/SPIRVUsage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ list of supported SPIR-V extensions, sorted alphabetically by their extension na
- Adds instructions to convert between single-precision 32-bit floating-point values and 16-bit bfloat16 values.
* - ``SPV_INTEL_cache_controls``
- Allows cache control information to be applied to memory access instructions.
* - ``SPV_INTEL_float_controls2``
- Adds execution modes and decorations to control floating-point computations.
* - ``SPV_INTEL_function_pointers``
- Allows translation of function pointers.
* - ``SPV_INTEL_inline_assembly``
Expand Down
9 changes: 4 additions & 5 deletions llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,10 @@ static bool hasType(const MCInst &MI, const MCInstrInfo &MII) {
// If we define an output, and have at least one other argument.
if (MCDesc.getNumDefs() == 1 && MCDesc.getNumOperands() >= 2) {
// Check if we define an ID, and take a type as operand 1.
auto &DefOpInfo = MCDesc.operands()[0];
auto &FirstArgOpInfo = MCDesc.operands()[1];
return DefOpInfo.RegClass >= 0 && FirstArgOpInfo.RegClass >= 0 &&
DefOpInfo.RegClass != SPIRV::TYPERegClassID &&
FirstArgOpInfo.RegClass == SPIRV::TYPERegClassID;
return MCDesc.operands()[0].RegClass >= 0 &&
MCDesc.operands()[1].RegClass >= 0 &&
MCDesc.operands()[0].RegClass != SPIRV::TYPERegClassID &&
MCDesc.operands()[1].RegClass == SPIRV::TYPERegClassID;
}
return false;
}
Expand Down
23 changes: 18 additions & 5 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@ using namespace InstructionSet;

namespace SPIRV {
/// Parses the name part of the demangled builtin call.
std::string lookupBuiltinNameHelper(StringRef DemangledCall) {
std::string lookupBuiltinNameHelper(StringRef DemangledCall,
std::string *Postfix) {
const static std::string PassPrefix = "(anonymous namespace)::";
std::string BuiltinName;
// Itanium Demangler result may have "(anonymous namespace)::" prefix
Expand Down Expand Up @@ -231,10 +232,13 @@ std::string lookupBuiltinNameHelper(StringRef DemangledCall) {
"ReadClockKHR|SubgroupBlockReadINTEL|SubgroupImageBlockReadINTEL|"
"SubgroupImageMediaBlockReadINTEL|SubgroupImageMediaBlockWriteINTEL|"
"Convert|"
"UConvert|SConvert|FConvert|SatConvert).*)_R.*");
"UConvert|SConvert|FConvert|SatConvert).*)_R(.*)");
std::smatch Match;
if (std::regex_match(BuiltinName, Match, SpvWithR) && Match.size() > 2)
if (std::regex_match(BuiltinName, Match, SpvWithR) && Match.size() > 3) {
BuiltinName = Match[1].str();
if (Postfix)
*Postfix = Match[3].str();
}

return BuiltinName;
}
Expand Down Expand Up @@ -583,6 +587,15 @@ static Register buildScopeReg(Register CLScopeRegister,
return buildConstantIntReg32(Scope, MIRBuilder, GR);
}

static void setRegClassIfNull(Register Reg, MachineRegisterInfo *MRI,
SPIRVGlobalRegistry *GR) {
if (MRI->getRegClassOrNull(Reg))
return;
SPIRVType *SpvType = GR->getSPIRVTypeForVReg(Reg);
MRI->setRegClass(Reg,
SpvType ? GR->getRegClass(SpvType) : &SPIRV::iIDRegClass);
}

static Register buildMemSemanticsReg(Register SemanticsRegister,
Register PtrRegister, unsigned &Semantics,
MachineIRBuilder &MIRBuilder,
Expand Down Expand Up @@ -1160,7 +1173,7 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
MIRBuilder.buildInstr(TargetOpcode::G_BUILD_VECTOR).addDef(VecReg);
for (unsigned i = 1; i < Call->Arguments.size(); i++) {
MIB.addUse(Call->Arguments[i]);
MRI->setRegClass(Call->Arguments[i], &SPIRV::iIDRegClass);
setRegClassIfNull(Call->Arguments[i], MRI, GR);
}
insertAssignInstr(VecReg, nullptr, VecType, GR, MIRBuilder,
MIRBuilder.getMF().getRegInfo());
Expand All @@ -1176,7 +1189,7 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
MIB.addImm(GroupBuiltin->GroupOperation);
if (Call->Arguments.size() > 0) {
MIB.addUse(Arg0.isValid() ? Arg0 : Call->Arguments[0]);
MRI->setRegClass(Call->Arguments[0], &SPIRV::iIDRegClass);
setRegClassIfNull(Call->Arguments[0], MRI, GR);
if (VecReg.isValid())
MIB.addUse(VecReg);
else
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
namespace llvm {
namespace SPIRV {
/// Parses the name part of the demangled builtin call.
std::string lookupBuiltinNameHelper(StringRef DemangledCall);
std::string lookupBuiltinNameHelper(StringRef DemangledCall,
std::string *Postfix = nullptr);
/// Lowers a builtin function call using the provided \p DemangledCall skeleton
/// and external instruction \p Set.
///
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
SPIRV::Extension::Extension::SPV_INTEL_arbitrary_precision_integers},
{"SPV_INTEL_cache_controls",
SPIRV::Extension::Extension::SPV_INTEL_cache_controls},
{"SPV_INTEL_float_controls2",
SPIRV::Extension::Extension::SPV_INTEL_float_controls2},
{"SPV_INTEL_global_variable_fpga_decorations",
SPIRV::Extension::Extension::
SPV_INTEL_global_variable_fpga_decorations},
Expand Down
98 changes: 97 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,8 @@ class SPIRVEmitIntrinsics
bool processFunctionPointers(Module &M);
void parseFunDeclarations(Module &M);

void useRoundingMode(ConstrainedFPIntrinsic *FPI, IRBuilder<> &B);

public:
static char ID;
SPIRVEmitIntrinsics() : ModulePass(ID) {
Expand Down Expand Up @@ -1291,6 +1293,37 @@ void SPIRVEmitIntrinsics::preprocessCompositeConstants(IRBuilder<> &B) {
}
}

static void createDecorationIntrinsic(Instruction *I, MDNode *Node,
IRBuilder<> &B) {
LLVMContext &Ctx = I->getContext();
setInsertPointAfterDef(B, I);
B.CreateIntrinsic(Intrinsic::spv_assign_decoration, {I->getType()},
{I, MetadataAsValue::get(Ctx, MDNode::get(Ctx, {Node}))});
}

static void createRoundingModeDecoration(Instruction *I,
unsigned RoundingModeDeco,
IRBuilder<> &B) {
LLVMContext &Ctx = I->getContext();
Type *Int32Ty = Type::getInt32Ty(Ctx);
MDNode *RoundingModeNode = MDNode::get(
Ctx,
{ConstantAsMetadata::get(
ConstantInt::get(Int32Ty, SPIRV::Decoration::FPRoundingMode)),
ConstantAsMetadata::get(ConstantInt::get(Int32Ty, RoundingModeDeco))});
createDecorationIntrinsic(I, RoundingModeNode, B);
}

static void createSaturatedConversionDecoration(Instruction *I,
IRBuilder<> &B) {
LLVMContext &Ctx = I->getContext();
Type *Int32Ty = Type::getInt32Ty(Ctx);
MDNode *SaturatedConversionNode =
MDNode::get(Ctx, {ConstantAsMetadata::get(ConstantInt::get(
Int32Ty, SPIRV::Decoration::SaturatedConversion))});
createDecorationIntrinsic(I, SaturatedConversionNode, B);
}

Instruction *SPIRVEmitIntrinsics::visitCallInst(CallInst &Call) {
if (!Call.isInlineAsm())
return &Call;
Expand All @@ -1312,6 +1345,40 @@ Instruction *SPIRVEmitIntrinsics::visitCallInst(CallInst &Call) {
return &Call;
}

// Use a tip about rounding mode to create a decoration.
void SPIRVEmitIntrinsics::useRoundingMode(ConstrainedFPIntrinsic *FPI,
IRBuilder<> &B) {
std::optional<RoundingMode> RM = FPI->getRoundingMode();
if (!RM.has_value())
return;
unsigned RoundingModeDeco = std::numeric_limits<unsigned>::max();
switch (RM.value()) {
default:
// ignore unknown rounding modes
break;
case RoundingMode::NearestTiesToEven:
RoundingModeDeco = SPIRV::FPRoundingMode::FPRoundingMode::RTE;
break;
case RoundingMode::TowardNegative:
RoundingModeDeco = SPIRV::FPRoundingMode::FPRoundingMode::RTN;
break;
case RoundingMode::TowardPositive:
RoundingModeDeco = SPIRV::FPRoundingMode::FPRoundingMode::RTP;
break;
case RoundingMode::TowardZero:
RoundingModeDeco = SPIRV::FPRoundingMode::FPRoundingMode::RTZ;
break;
case RoundingMode::Dynamic:
case RoundingMode::NearestTiesToAway:
// TODO: check if supported
break;
}
if (RoundingModeDeco == std::numeric_limits<unsigned>::max())
return;
// Convert the tip about rounding mode into a decoration record.
createRoundingModeDecoration(FPI, RoundingModeDeco, B);
}

Instruction *SPIRVEmitIntrinsics::visitSwitchInst(SwitchInst &I) {
BasicBlock *ParentBB = I.getParent();
IRBuilder<> B(ParentBB);
Expand Down Expand Up @@ -1809,6 +1876,18 @@ bool SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I,
return true;
}

static unsigned roundingModeMDToDecorationConst(StringRef S) {
if (S == "rte")
return SPIRV::FPRoundingMode::FPRoundingMode::RTE;
if (S == "rtz")
return SPIRV::FPRoundingMode::FPRoundingMode::RTZ;
if (S == "rtp")
return SPIRV::FPRoundingMode::FPRoundingMode::RTP;
if (S == "rtn")
return SPIRV::FPRoundingMode::FPRoundingMode::RTN;
return std::numeric_limits<unsigned>::max();
}

void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
IRBuilder<> &B) {
// TODO: extend the list of functions with known result types
Expand All @@ -1826,8 +1905,9 @@ void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
Function *CalledF = CI->getCalledFunction();
std::string DemangledName =
getOclOrSpirvBuiltinDemangledName(CalledF->getName());
std::string Postfix;
if (DemangledName.length() > 0)
DemangledName = SPIRV::lookupBuiltinNameHelper(DemangledName);
DemangledName = SPIRV::lookupBuiltinNameHelper(DemangledName, &Postfix);
auto ResIt = ResTypeWellKnown.find(DemangledName);
if (ResIt != ResTypeWellKnown.end()) {
IsKnown = true;
Expand All @@ -1839,6 +1919,19 @@ void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
break;
}
}
// check if a floating rounding mode info is present
StringRef S = Postfix;
SmallVector<StringRef, 8> Parts;
S.split(Parts, "_", -1, false);
if (Parts.size() > 1) {
// Convert the info about rounding mode into a decoration record.
unsigned RoundingModeDeco = roundingModeMDToDecorationConst(Parts[1]);
if (RoundingModeDeco != std::numeric_limits<unsigned>::max())
createRoundingModeDecoration(CI, RoundingModeDeco, B);
// Check if the SaturatedConversion info is present.
if (Parts[1] == "sat")
createSaturatedConversionDecoration(CI, B);
}
}
}

Expand Down Expand Up @@ -2264,6 +2357,9 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
// already, and force it to be i8 if not
if (Postpone && !GR->findAssignPtrTypeInstr(I))
insertAssignPtrTypeIntrs(I, B, true);

if (auto *FPI = dyn_cast<ConstrainedFPIntrinsic>(I))
useRoundingMode(FPI, B);
}

// Pass backward: use instructions results to specify/update/cast operands
Expand Down
14 changes: 7 additions & 7 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,14 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
Width = adjustOpTypeIntWidth(Width);
const SPIRVSubtarget &ST =
cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
if (ST.canUseExtension(
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
MIRBuilder.buildInstr(SPIRV::OpExtension)
.addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
MIRBuilder.buildInstr(SPIRV::OpCapability)
.addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
}
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
if (ST.canUseExtension(
SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
MIRBuilder.buildInstr(SPIRV::OpExtension)
.addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
MIRBuilder.buildInstr(SPIRV::OpCapability)
.addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
}
return MIRBuilder.buildInstr(SPIRV::OpTypeInt)
.addDef(createTypeVReg(MIRBuilder))
.addImm(Width)
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -491,23 +491,29 @@ def OpFNegate: UnOpTyped<"OpFNegate", 127, fID, fneg>;
def OpFNegateV: UnOpTyped<"OpFNegate", 127, vfID, fneg>;
defm OpIAdd: BinOpTypedGen<"OpIAdd", 128, add, 0, 1>;
defm OpFAdd: BinOpTypedGen<"OpFAdd", 129, fadd, 1, 1>;
defm OpStrictFAdd: BinOpTypedGen<"OpFAdd", 129, strict_fadd, 1, 1>;

defm OpISub: BinOpTypedGen<"OpISub", 130, sub, 0, 1>;
defm OpFSub: BinOpTypedGen<"OpFSub", 131, fsub, 1, 1>;
defm OpStrictFSub: BinOpTypedGen<"OpFSub", 131, strict_fsub, 1, 1>;

defm OpIMul: BinOpTypedGen<"OpIMul", 132, mul, 0, 1>;
defm OpFMul: BinOpTypedGen<"OpFMul", 133, fmul, 1, 1>;
defm OpStrictFMul: BinOpTypedGen<"OpFMul", 133, strict_fmul, 1, 1>;

defm OpUDiv: BinOpTypedGen<"OpUDiv", 134, udiv, 0, 1>;
defm OpSDiv: BinOpTypedGen<"OpSDiv", 135, sdiv, 0, 1>;
defm OpFDiv: BinOpTypedGen<"OpFDiv", 136, fdiv, 1, 1>;
defm OpStrictFDiv: BinOpTypedGen<"OpFDiv", 136, strict_fdiv, 1, 1>;

defm OpUMod: BinOpTypedGen<"OpUMod", 137, urem, 0, 1>;
defm OpSRem: BinOpTypedGen<"OpSRem", 138, srem, 0, 1>;

def OpSMod: BinOp<"OpSMod", 139>;

defm OpFRem: BinOpTypedGen<"OpFRem", 140, frem, 1, 1>;
defm OpStrictFRem: BinOpTypedGen<"OpFRem", 140, strict_frem, 1, 1>;

def OpFMod: BinOp<"OpFMod", 141>;

def OpVectorTimesScalar: BinOp<"OpVectorTimesScalar", 142>;
Expand Down
38 changes: 37 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class SPIRVInstructionSelector : public InstructionSelector {
/// We need to keep track of the number we give to anonymous global values to
/// generate the same name every time when this is needed.
mutable DenseMap<const GlobalValue *, unsigned> UnnamedGlobalIDs;
SmallPtrSet<MachineInstr *, 8> DeadMIs;

public:
SPIRVInstructionSelector(const SPIRVTargetMachine &TM,
Expand Down Expand Up @@ -382,6 +383,24 @@ static bool isImm(const MachineOperand &MO, MachineRegisterInfo *MRI);
// Defined in SPIRVLegalizerInfo.cpp.
extern bool isTypeFoldingSupported(unsigned Opcode);

bool isDead(const MachineInstr &MI, const MachineRegisterInfo &MRI) {
for (const auto &MO : MI.all_defs()) {
Register Reg = MO.getReg();
if (Reg.isPhysical() || !MRI.use_nodbg_empty(Reg))
return false;
}
if (MI.getOpcode() == TargetOpcode::LOCAL_ESCAPE || MI.isFakeUse() ||
MI.isLifetimeMarker())
return false;
if (MI.isPHI())
return true;
if (MI.mayStore() || MI.isCall() ||
(MI.mayLoad() && MI.hasOrderedMemoryRef()) || MI.isPosition() ||
MI.isDebugInstr() || MI.isTerminator() || MI.isJumpTableDebugInfo())
return false;
return true;
}

bool SPIRVInstructionSelector::select(MachineInstr &I) {
resetVRegsType(*I.getParent()->getParent());

Expand All @@ -404,8 +423,11 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
}
});
assert(Res || Def->getOpcode() == TargetOpcode::G_CONSTANT);
if (Res)
if (Res) {
if (!isTriviallyDead(*Def, *MRI) && isDead(*Def, *MRI))
DeadMIs.insert(Def);
return Res;
}
}
MRI->setRegClass(SrcReg, MRI->getRegClass(DstReg));
MRI->replaceRegWith(SrcReg, DstReg);
Expand All @@ -418,6 +440,15 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
return constrainSelectedInstRegOperands(I, TII, TRI, RBI);
}

if (DeadMIs.contains(&I)) {
// if the instruction has been already made dead by folding it away
// erase it
LLVM_DEBUG(dbgs() << "Instruction is folded and dead.\n");
salvageDebugInfo(*MRI, I);
I.eraseFromParent();
return true;
}

if (I.getNumOperands() != I.getNumExplicitOperands()) {
LLVM_DEBUG(errs() << "Generic instr has unexpected implicit operands\n");
return false;
Expand Down Expand Up @@ -557,9 +588,13 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
case TargetOpcode::G_UCMP:
return selectSUCmp(ResVReg, ResType, I, false);

case TargetOpcode::G_STRICT_FMA:
case TargetOpcode::G_FMA:
return selectExtInst(ResVReg, ResType, I, CL::fma, GL::Fma);

case TargetOpcode::G_STRICT_FLDEXP:
return selectExtInst(ResVReg, ResType, I, CL::ldexp);

case TargetOpcode::G_FPOW:
return selectExtInst(ResVReg, ResType, I, CL::pow, GL::Pow);
case TargetOpcode::G_FPOWI:
Expand Down Expand Up @@ -618,6 +653,7 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
case TargetOpcode::G_FTANH:
return selectExtInst(ResVReg, ResType, I, CL::tanh, GL::Tanh);

case TargetOpcode::G_STRICT_FSQRT:
case TargetOpcode::G_FSQRT:
return selectExtInst(ResVReg, ResType, I, CL::sqrt, GL::Sqrt);

Expand Down
Loading
Loading