Skip to content

Commit 925768e

Browse files
Add support for atomic instructions on floating-point numbers (#81683)
This PR adds support for atomic instructions on floating-point numbers, covering the following extensions: SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max, and SPV_EXT_shader_atomic_float16_add. It also fixes the asm printer output for the half floating-point type.
1 parent e488fe5 commit 925768e

18 files changed

+517
-21
lines changed

llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "SPIRVInstPrinter.h"
1414
#include "SPIRV.h"
1515
#include "SPIRVBaseInfo.h"
16+
#include "SPIRVInstrInfo.h"
1617
#include "llvm/ADT/APFloat.h"
1718
#include "llvm/CodeGen/Register.h"
1819
#include "llvm/MC/MCAsmInfo.h"
@@ -50,6 +51,7 @@ void SPIRVInstPrinter::printRemainingVariableOps(const MCInst *MI,
5051
void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI,
5152
unsigned StartIndex,
5253
raw_ostream &O) {
54+
unsigned IsBitwidth16 = MI->getFlags() & SPIRV::ASM_PRINTER_WIDTH16;
5355
const unsigned NumVarOps = MI->getNumOperands() - StartIndex;
5456

5557
assert((NumVarOps == 1 || NumVarOps == 2) &&
@@ -65,7 +67,7 @@ void SPIRVInstPrinter::printOpConstantVarOps(const MCInst *MI,
6567
}
6668

6769
// Format and print float values.
68-
if (MI->getOpcode() == SPIRV::OpConstantF) {
70+
if (MI->getOpcode() == SPIRV::OpConstantF && IsBitwidth16 == 0) {
6971
APFloat FP = NumVarOps == 1 ? APFloat(APInt(32, Imm).bitsToFloat())
7072
: APFloat(APInt(64, Imm).bitsToDouble());
7173

llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,14 @@ struct IntelSubgroupsBuiltin {
9393
#define GET_IntelSubgroupsBuiltins_DECL
9494
#define GET_IntelSubgroupsBuiltins_IMPL
9595

96+
// Row type for the table of atomic floating-point builtins: associates the
// demangled name of a builtin with the SPIR-V opcode of the instruction that
// is generated for it.
struct AtomicFloatingBuiltin {
97+
// Demangled builtin name; this is the value passed to
// lookupAtomicFloatingBuiltin when the opcode is looked up.
StringRef Name;
98+
// SPIR-V operation code of the instruction to generate.
uint32_t Opcode;
99+
};
100+
101+
// NOTE(review): presumably these macros select the declaration and
// implementation sections of a TableGen-generated include that is not visible
// in this hunk -- confirm against the full file.
#define GET_AtomicFloatingBuiltins_DECL
102+
#define GET_AtomicFloatingBuiltins_IMPL
103+
96104
struct GetBuiltin {
97105
StringRef Name;
98106
InstructionSet::InstructionSet Set;
@@ -402,7 +410,7 @@ getSPIRVMemSemantics(std::memory_order MemOrder) {
402410
case std::memory_order::memory_order_seq_cst:
403411
return SPIRV::MemorySemantics::SequentiallyConsistent;
404412
default:
405-
llvm_unreachable("Unknown CL memory scope");
413+
report_fatal_error("Unknown CL memory scope");
406414
}
407415
}
408416

@@ -419,7 +427,7 @@ static SPIRV::Scope::Scope getSPIRVScope(SPIRV::CLMemoryScope ClScope) {
419427
case SPIRV::CLMemoryScope::memory_scope_sub_group:
420428
return SPIRV::Scope::Subgroup;
421429
}
422-
llvm_unreachable("Unknown CL memory scope");
430+
report_fatal_error("Unknown CL memory scope");
423431
}
424432

425433
static Register buildConstantIntReg(uint64_t Val, MachineIRBuilder &MIRBuilder,
@@ -676,6 +684,38 @@ static bool buildAtomicRMWInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
676684
return true;
677685
}
678686

687+
/// Helper function for building an atomic floating-type instruction.
688+
static bool buildAtomicFloatingRMWInst(const SPIRV::IncomingCall *Call,
689+
unsigned Opcode,
690+
MachineIRBuilder &MIRBuilder,
691+
SPIRVGlobalRegistry *GR) {
692+
assert(Call->Arguments.size() == 4 &&
693+
"Wrong number of atomic floating-type builtin");
694+
695+
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
696+
697+
Register PtrReg = Call->Arguments[0];
698+
MRI->setRegClass(PtrReg, &SPIRV::IDRegClass);
699+
700+
Register ScopeReg = Call->Arguments[1];
701+
MRI->setRegClass(ScopeReg, &SPIRV::IDRegClass);
702+
703+
Register MemSemanticsReg = Call->Arguments[2];
704+
MRI->setRegClass(MemSemanticsReg, &SPIRV::IDRegClass);
705+
706+
Register ValueReg = Call->Arguments[3];
707+
MRI->setRegClass(ValueReg, &SPIRV::IDRegClass);
708+
709+
MIRBuilder.buildInstr(Opcode)
710+
.addDef(Call->ReturnRegister)
711+
.addUse(GR->getSPIRVTypeID(Call->ReturnType))
712+
.addUse(PtrReg)
713+
.addUse(ScopeReg)
714+
.addUse(MemSemanticsReg)
715+
.addUse(ValueReg);
716+
return true;
717+
}
718+
679719
/// Helper function for building atomic flag instructions (e.g.
680720
/// OpAtomicFlagTestAndSet).
681721
static bool buildAtomicFlagInst(const SPIRV::IncomingCall *Call,
@@ -786,7 +826,7 @@ static unsigned getNumComponentsForDim(SPIRV::Dim::Dim dim) {
786826
case SPIRV::Dim::DIM_3D:
787827
return 3;
788828
default:
789-
llvm_unreachable("Cannot get num components for given Dim");
829+
report_fatal_error("Cannot get num components for given Dim");
790830
}
791831
}
792832

@@ -1157,6 +1197,23 @@ static bool generateAtomicInst(const SPIRV::IncomingCall *Call,
11571197
}
11581198
}
11591199

1200+
/// Lowers a call to an atomic floating-point builtin.
///
/// The opcode is taken from the TableGen record registered under the
/// builtin's demangled name. OpAtomicFAddEXT, OpAtomicFMinEXT and
/// OpAtomicFMaxEXT are forwarded to buildAtomicFloatingRMWInst.
///
/// \returns the result of buildAtomicFloatingRMWInst for the supported
/// opcodes; false when no record exists for the builtin name or the record
/// carries any other opcode.
static bool generateAtomicFloatingInst(const SPIRV::IncomingCall *Call,
                                       MachineIRBuilder &MIRBuilder,
                                       SPIRVGlobalRegistry *GR) {
  // Lookup the instruction opcode in the TableGen records.
  const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
  const auto *Record = SPIRV::lookupAtomicFloatingBuiltin(Builtin->Name);
  // The generated lookup yields a null pointer when nothing matches; report
  // the failure to the caller instead of dereferencing it.
  if (!Record)
    return false;

  const unsigned Opcode = Record->Opcode;
  switch (Opcode) {
  case SPIRV::OpAtomicFAddEXT:
  case SPIRV::OpAtomicFMinEXT:
  case SPIRV::OpAtomicFMaxEXT:
    return buildAtomicFloatingRMWInst(Call, Opcode, MIRBuilder, GR);
  default:
    return false;
  }
}
1216+
11601217
static bool generateBarrierInst(const SPIRV::IncomingCall *Call,
11611218
MachineIRBuilder &MIRBuilder,
11621219
SPIRVGlobalRegistry *GR) {
@@ -1311,7 +1368,7 @@ getSamplerAddressingModeFromBitmask(unsigned Bitmask) {
13111368
case SPIRV::CLK_ADDRESS_NONE:
13121369
return SPIRV::SamplerAddressingMode::None;
13131370
default:
1314-
llvm_unreachable("Unknown CL address mode");
1371+
report_fatal_error("Unknown CL address mode");
13151372
}
13161373
}
13171374

@@ -2021,6 +2078,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
20212078
return generateBuiltinVar(Call.get(), MIRBuilder, GR);
20222079
case SPIRV::Atomic:
20232080
return generateAtomicInst(Call.get(), MIRBuilder, GR);
2081+
case SPIRV::AtomicFloating:
2082+
return generateAtomicFloatingInst(Call.get(), MIRBuilder, GR);
20242083
case SPIRV::Barrier:
20252084
return generateBarrierInst(Call.get(), MIRBuilder, GR);
20262085
case SPIRV::Dot:
@@ -2089,7 +2148,7 @@ static Type *parseTypeString(const StringRef Name, LLVMContext &Context) {
20892148
return Type::getFloatTy(Context);
20902149
else if (Name.starts_with("half"))
20912150
return Type::getHalfTy(Context);
2092-
llvm_unreachable("Unable to recognize type!");
2151+
report_fatal_error("Unable to recognize type!");
20932152
}
20942153

20952154
//===----------------------------------------------------------------------===//

llvm/lib/Target/SPIRV/SPIRVBuiltins.td

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def AsyncCopy : BuiltinGroup;
5555
def VectorLoadStore : BuiltinGroup;
5656
def LoadStore : BuiltinGroup;
5757
def IntelSubgroups : BuiltinGroup;
58+
def AtomicFloating : BuiltinGroup;
5859

5960
//===----------------------------------------------------------------------===//
6061
// Class defining a demangled builtin record. The information in the record
@@ -872,6 +873,44 @@ defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_logical_xors", Wo
872873
defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_logical_xors", WorkOrSub, OpGroupNonUniformLogicalXor>;
873874
defm : DemangledGroupBuiltin<"group_clustered_reduce_logical_xor", WorkOrSub, OpGroupNonUniformLogicalXor>;
874875

876+
//===----------------------------------------------------------------------===//
877+
// Class defining an atomic instruction on floating-point numbers.
878+
//
879+
// name is the demangled name of the given builtin.
880+
// opcode specifies the SPIR-V operation code of the generated instruction.
881+
//===----------------------------------------------------------------------===//
882+
// Record pairing the demangled name of an atomic floating-point builtin with
// the SPIR-V operation generated for it.
class AtomicFloatingBuiltin<string name, Op operation> {
883+
// Demangled builtin name.
string Name = name;
884+
// SPIR-V operation of the generated instruction.
Op Opcode = operation;
885+
}
886+
887+
// Table gathering all builtins for atomic instructions on floating-point numbers
888+
def AtomicFloatingBuiltins : GenericTable {
889+
// Rows are taken from the records of class AtomicFloatingBuiltin.
let FilterClass = "AtomicFloatingBuiltin";
890+
// Columns of each row: the builtin name and the operation generated for it.
let Fields = ["Name", "Opcode"];
891+
}
892+
893+
// Function to look up atomic floating-point builtins by their name (the name
// is the only key).
894+
def lookupAtomicFloatingBuiltin : SearchIndex {
895+
let Table = AtomicFloatingBuiltins;
896+
let Key = ["Name"];
897+
}
898+
899+
// Multiclass used to define incoming demangled builtin records and
900+
// corresponding builtin records for atomic instructions on floating-point numbers.
901+
// For a given name suffix, emits two records that share the demangled name
// "__spirv_AtomicF" # name: a DemangledBuiltin and an AtomicFloatingBuiltin.
multiclass DemangledAtomicFloatingBuiltin<string name, bits<8> minNumArgs, bits<8> maxNumArgs, Op operation> {
902+
// Demangled builtin record in the OpenCL_std set and the AtomicFloating
// group, with the accepted argument-count range.
def : DemangledBuiltin<!strconcat("__spirv_AtomicF", name), OpenCL_std, AtomicFloating, minNumArgs, maxNumArgs>;
903+
// Matching name -> operation record for the AtomicFloatingBuiltins table.
def : AtomicFloatingBuiltin<!strconcat("__spirv_AtomicF", name), operation>;
904+
}
905+
906+
// SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max, SPV_EXT_shader_atomic_float16_add
907+
// Atomic add, min and max instruction on floating-point numbers:
908+
// Defines __spirv_AtomicFAddEXT, __spirv_AtomicFMinEXT and
// __spirv_AtomicFMaxEXT; each accepts exactly four arguments (the minimum and
// maximum argument counts are both 4).
defm : DemangledAtomicFloatingBuiltin<"AddEXT", 4, 4, OpAtomicFAddEXT>;
909+
defm : DemangledAtomicFloatingBuiltin<"MinEXT", 4, 4, OpAtomicFMinEXT>;
910+
defm : DemangledAtomicFloatingBuiltin<"MaxEXT", 4, 4, OpAtomicFMaxEXT>;
911+
// TODO: add support for cl_ext_float_atomics to enable performing atomic operations
912+
// on floating-point numbers in memory (float arguments for atomic_fetch_add, ...)
913+
875914
//===----------------------------------------------------------------------===//
876915
// Class defining a sub group builtin that should be translated into a
877916
// SPIR-V instruction using the SPV_INTEL_subgroups extension.

llvm/lib/Target/SPIRV/SPIRVInstrInfo.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,14 @@ class SPIRVInstrInfo : public SPIRVGenInstrInfo {
5353
bool KillSrc) const override;
5454
bool expandPostRAPseudo(MachineInstr &MI) const override;
5555
};
56+
57+
namespace SPIRV {
58+
enum AsmComments {
59+
// It is a half type
60+
ASM_PRINTER_WIDTH16 = MachineInstr::TAsmComments
61+
};
62+
}; // namespace SPIRV
63+
5664
} // namespace llvm
5765

5866
#endif // LLVM_LIB_TARGET_SPIRV_SPIRVINSTRINFO_H

llvm/lib/Target/SPIRV/SPIRVInstrInfo.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,9 @@ def OpAtomicAnd: AtomicOpVal<"OpAtomicAnd", 240>;
643643
def OpAtomicOr: AtomicOpVal<"OpAtomicOr", 241>;
644644
def OpAtomicXor: AtomicOpVal<"OpAtomicXor", 242>;
645645

646+
// Atomic add, min and max on floating-point values (EXT shader_atomic_float
// extensions). They reuse the AtomicOpVal class used for the integer atomics
// above.
def OpAtomicFAddEXT: AtomicOpVal<"OpAtomicFAddEXT", 6035>;
647+
def OpAtomicFMinEXT: AtomicOpVal<"OpAtomicFMinEXT", 5614>;
648+
def OpAtomicFMaxEXT: AtomicOpVal<"OpAtomicFMaxEXT", 5615>;
646649

647650
def OpAtomicFlagTestAndSet: AtomicOp<"OpAtomicFlagTestAndSet", 318>;
648651
def OpAtomicFlagClear: Op<319, (outs), (ins ID:$ptr, ID:$sc, ID:$sem),

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,8 @@ class SPIRVInstructionSelector : public InstructionSelector {
102102
bool selectMemOperation(Register ResVReg, MachineInstr &I) const;
103103

104104
bool selectAtomicRMW(Register ResVReg, const SPIRVType *ResType,
105-
MachineInstr &I, unsigned NewOpcode) const;
105+
MachineInstr &I, unsigned NewOpcode,
106+
unsigned NegateOpcode = 0) const;
106107

107108
bool selectAtomicCmpXchg(Register ResVReg, const SPIRVType *ResType,
108109
MachineInstr &I) const;
@@ -489,6 +490,17 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
489490
case TargetOpcode::G_ATOMIC_CMPXCHG:
490491
return selectAtomicCmpXchg(ResVReg, ResType, I);
491492

493+
case TargetOpcode::G_ATOMICRMW_FADD:
494+
return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicFAddEXT);
495+
case TargetOpcode::G_ATOMICRMW_FSUB:
496+
// Translate G_ATOMICRMW_FSUB to OpAtomicFAddEXT with negative value operand
497+
return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicFAddEXT,
498+
SPIRV::OpFNegate);
499+
case TargetOpcode::G_ATOMICRMW_FMIN:
500+
return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicFMinEXT);
501+
case TargetOpcode::G_ATOMICRMW_FMAX:
502+
return selectAtomicRMW(ResVReg, ResType, I, SPIRV::OpAtomicFMaxEXT);
503+
492504
case TargetOpcode::G_FENCE:
493505
return selectFence(I);
494506

@@ -686,7 +698,8 @@ bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg,
686698
bool SPIRVInstructionSelector::selectAtomicRMW(Register ResVReg,
687699
const SPIRVType *ResType,
688700
MachineInstr &I,
689-
unsigned NewOpcode) const {
701+
unsigned NewOpcode,
702+
unsigned NegateOpcode) const {
690703
assert(I.hasOneMemOperand());
691704
const MachineMemOperand *MemOp = *I.memoperands_begin();
692705
uint32_t Scope = static_cast<uint32_t>(getScope(MemOp->getSyncScopeID()));
@@ -700,14 +713,24 @@ bool SPIRVInstructionSelector::selectAtomicRMW(Register ResVReg,
700713
uint32_t MemSem = static_cast<uint32_t>(getMemSemantics(AO));
701714
Register MemSemReg = buildI32Constant(MemSem /*| ScSem*/, I);
702715

703-
return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(NewOpcode))
704-
.addDef(ResVReg)
705-
.addUse(GR.getSPIRVTypeID(ResType))
706-
.addUse(Ptr)
707-
.addUse(ScopeReg)
708-
.addUse(MemSemReg)
709-
.addUse(I.getOperand(2).getReg())
710-
.constrainAllUses(TII, TRI, RBI);
716+
bool Result = false;
717+
Register ValueReg = I.getOperand(2).getReg();
718+
if (NegateOpcode != 0) {
719+
// Translation with negative value operand is requested
720+
Register TmpReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
721+
Result |= selectUnOpWithSrc(TmpReg, ResType, I, ValueReg, NegateOpcode);
722+
ValueReg = TmpReg;
723+
}
724+
725+
Result |= BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(NewOpcode))
726+
.addDef(ResVReg)
727+
.addUse(GR.getSPIRVTypeID(ResType))
728+
.addUse(Ptr)
729+
.addUse(ScopeReg)
730+
.addUse(MemSemReg)
731+
.addUse(ValueReg)
732+
.constrainAllUses(TII, TRI, RBI);
733+
return Result;
711734
}
712735

713736
bool SPIRVInstructionSelector::selectFence(MachineInstr &I) const {

llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
125125

126126
auto allIntScalars = {s8, s16, s32, s64};
127127

128+
auto allFloatScalars = {s16, s32, s64};
129+
128130
auto allFloatScalarsAndVectors = {
129131
s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64,
130132
v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
@@ -205,6 +207,10 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
205207
G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
206208
.legalForCartesianProduct(allIntScalars, allWritablePtrs);
207209

210+
getActionDefinitionsBuilder(
211+
{G_ATOMICRMW_FADD, G_ATOMICRMW_FSUB, G_ATOMICRMW_FMIN, G_ATOMICRMW_FMAX})
212+
.legalForCartesianProduct(allFloatScalars, allWritablePtrs);
213+
208214
getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
209215
.legalForCartesianProduct(allFloatAndIntScalars, allWritablePtrs);
210216

llvm/lib/Target/SPIRV/SPIRVMCInstLower.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ using namespace llvm;
2323
void SPIRVMCInstLower::lower(const MachineInstr *MI, MCInst &OutMI,
2424
SPIRV::ModuleAnalysisInfo *MAI) const {
2525
OutMI.setOpcode(MI->getOpcode());
26+
// Propagate previously set flags
27+
OutMI.setFlags(MI->getAsmPrinterFlags());
2628
const MachineFunction *MF = MI->getMF();
2729
for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) {
2830
const MachineOperand &MO = MI->getOperand(i);

0 commit comments

Comments
 (0)