Skip to content

Commit 378630b

Browse files
[SPIR-V] Support cl_ext_float_atomics and fix errors in definition of atomic_fetch_*_explicit builtins (#96767)
This PR: * supports cl_ext_float_atomics by mapping atomic_fetch_add and atomic_fetch_sub applied to float arguments to the corresponding instructions from the SPV_EXT_shader_atomic_float*_add extensions (atomic_fetch_sub is emitted as OpAtomicFAddEXT with a negated value operand), and * fixes errors in the definitions of the atomic_fetch_*_explicit builtins by correcting their valid number of arguments (from 4-6 to 3-4).
1 parent bb50bc2 commit 378630b

File tree

3 files changed

+46
-10
lines changed

3 files changed

+46
-10
lines changed

llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -765,7 +765,7 @@ static bool buildAtomicCompareExchangeInst(
765765
return true;
766766
}
767767

768-
/// Helper function for building an atomic load instruction.
768+
/// Helper function for building atomic instructions.
769769
static bool buildAtomicRMWInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
770770
MachineIRBuilder &MIRBuilder,
771771
SPIRVGlobalRegistry *GR) {
@@ -790,13 +790,36 @@ static bool buildAtomicRMWInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
790790
MemSemanticsReg = buildMemSemanticsReg(MemSemanticsReg, PtrRegister,
791791
Semantics, MIRBuilder, GR);
792792
MRI->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
793+
Register ValueReg = Call->Arguments[1];
794+
Register ValueTypeReg = GR->getSPIRVTypeID(Call->ReturnType);
795+
// support cl_ext_float_atomics
796+
if (Call->ReturnType->getOpcode() == SPIRV::OpTypeFloat) {
797+
if (Opcode == SPIRV::OpAtomicIAdd) {
798+
Opcode = SPIRV::OpAtomicFAddEXT;
799+
} else if (Opcode == SPIRV::OpAtomicISub) {
800+
// Translate OpAtomicISub applied to a floating type argument to
801+
// OpAtomicFAddEXT with the negative value operand
802+
Opcode = SPIRV::OpAtomicFAddEXT;
803+
Register NegValueReg =
804+
MRI->createGenericVirtualRegister(MRI->getType(ValueReg));
805+
MRI->setRegClass(NegValueReg, &SPIRV::IDRegClass);
806+
GR->assignSPIRVTypeToVReg(Call->ReturnType, NegValueReg,
807+
MIRBuilder.getMF());
808+
MIRBuilder.buildInstr(TargetOpcode::G_FNEG)
809+
.addDef(NegValueReg)
810+
.addUse(ValueReg);
811+
insertAssignInstr(NegValueReg, nullptr, Call->ReturnType, GR, MIRBuilder,
812+
MIRBuilder.getMF().getRegInfo());
813+
ValueReg = NegValueReg;
814+
}
815+
}
793816
MIRBuilder.buildInstr(Opcode)
794817
.addDef(Call->ReturnRegister)
795-
.addUse(GR->getSPIRVTypeID(Call->ReturnType))
818+
.addUse(ValueTypeReg)
796819
.addUse(PtrRegister)
797820
.addUse(ScopeRegister)
798821
.addUse(MemSemanticsReg)
799-
.addUse(Call->Arguments[1]);
822+
.addUse(ValueReg);
800823
return true;
801824
}
802825

llvm/lib/Target/SPIRV/SPIRVBuiltins.td

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -606,11 +606,11 @@ defm : DemangledNativeBuiltin<"atomic_fetch_sub", OpenCL_std, Atomic, 2, 4, OpAt
606606
defm : DemangledNativeBuiltin<"atomic_fetch_or", OpenCL_std, Atomic, 2, 4, OpAtomicOr>;
607607
defm : DemangledNativeBuiltin<"atomic_fetch_xor", OpenCL_std, Atomic, 2, 4, OpAtomicXor>;
608608
defm : DemangledNativeBuiltin<"atomic_fetch_and", OpenCL_std, Atomic, 2, 4, OpAtomicAnd>;
609-
defm : DemangledNativeBuiltin<"atomic_fetch_add_explicit", OpenCL_std, Atomic, 4, 6, OpAtomicIAdd>;
610-
defm : DemangledNativeBuiltin<"atomic_fetch_sub_explicit", OpenCL_std, Atomic, 4, 6, OpAtomicISub>;
611-
defm : DemangledNativeBuiltin<"atomic_fetch_or_explicit", OpenCL_std, Atomic, 4, 6, OpAtomicOr>;
612-
defm : DemangledNativeBuiltin<"atomic_fetch_xor_explicit", OpenCL_std, Atomic, 4, 6, OpAtomicXor>;
613-
defm : DemangledNativeBuiltin<"atomic_fetch_and_explicit", OpenCL_std, Atomic, 4, 6, OpAtomicAnd>;
609+
defm : DemangledNativeBuiltin<"atomic_fetch_add_explicit", OpenCL_std, Atomic, 3, 4, OpAtomicIAdd>;
610+
defm : DemangledNativeBuiltin<"atomic_fetch_sub_explicit", OpenCL_std, Atomic, 3, 4, OpAtomicISub>;
611+
defm : DemangledNativeBuiltin<"atomic_fetch_or_explicit", OpenCL_std, Atomic, 3, 4, OpAtomicOr>;
612+
defm : DemangledNativeBuiltin<"atomic_fetch_xor_explicit", OpenCL_std, Atomic, 3, 4, OpAtomicXor>;
613+
defm : DemangledNativeBuiltin<"atomic_fetch_and_explicit", OpenCL_std, Atomic, 3, 4, OpAtomicAnd>;
614614
defm : DemangledNativeBuiltin<"atomic_flag_test_and_set", OpenCL_std, Atomic, 1, 1, OpAtomicFlagTestAndSet>;
615615
defm : DemangledNativeBuiltin<"__spirv_AtomicFlagTestAndSet", OpenCL_std, Atomic, 3, 3, OpAtomicFlagTestAndSet>;
616616
defm : DemangledNativeBuiltin<"atomic_flag_test_and_set_explicit", OpenCL_std, Atomic, 2, 3, OpAtomicFlagTestAndSet>;
@@ -1097,8 +1097,6 @@ multiclass DemangledAtomicFloatingBuiltin<string name, bits<8> minNumArgs, bits<
10971097
defm : DemangledAtomicFloatingBuiltin<"AddEXT", 4, 4, OpAtomicFAddEXT>;
10981098
defm : DemangledAtomicFloatingBuiltin<"MinEXT", 4, 4, OpAtomicFMinEXT>;
10991099
defm : DemangledAtomicFloatingBuiltin<"MaxEXT", 4, 4, OpAtomicFMaxEXT>;
1100-
// TODO: add support for cl_ext_float_atomics to enable performing atomic operations
1101-
// on floating-point numbers in memory (float arguments for atomic_fetch_add, ...)
11021100

11031101
//===----------------------------------------------------------------------===//
11041102
// Class defining a sub group builtin that should be translated into a

llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_shader_atomic_float_add/atomicrmw_faddfsub_float.ll

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,18 @@
1111
; CHECK-DAG: %[[Const0:[0-9]+]] = OpConstant %[[TyFP32]] 0
1212
; CHECK-DAG: %[[Const42:[0-9]+]] = OpConstant %[[TyFP32]] 42
1313
; CHECK-DAG: %[[ScopeDevice:[0-9]+]] = OpConstant %[[TyInt32]] 1
14+
; CHECK-DAG: %[[ScopeWorkgroup:[0-9]+]] = OpConstant %[[TyInt32]] 2
1415
; CHECK-DAG: %[[MemSeqCst:[0-9]+]] = OpConstant %[[TyInt32]] 16
16+
; CHECK-DAG: %[[WorkgroupMemory:[0-9]+]] = OpConstant %[[TyInt32]] 512
1517
; CHECK-DAG: %[[TyFP32Ptr:[0-9]+]] = OpTypePointer {{[a-zA-Z]+}} %[[TyFP32]]
1618
; CHECK-DAG: %[[DblPtr:[0-9]+]] = OpVariable %[[TyFP32Ptr]] {{[a-zA-Z]+}} %[[Const0]]
1719
; CHECK: OpAtomicFAddEXT %[[TyFP32]] %[[DblPtr]] %[[ScopeDevice]] %[[MemSeqCst]] %[[Const42]]
1820
; CHECK: %[[Const42Neg:[0-9]+]] = OpFNegate %[[TyFP32]] %[[Const42]]
1921
; CHECK: OpAtomicFAddEXT %[[TyFP32]] %[[DblPtr]] %[[ScopeDevice]] %[[MemSeqCst]] %[[Const42Neg]]
2022
; CHECK: OpAtomicFAddEXT %[[TyFP32]] %[[DblPtr]] %[[ScopeDevice]] %[[MemSeqCst]] %[[Const42]]
23+
; CHECK: OpAtomicFAddEXT %[[TyFP32]] %[[DblPtr]] %[[ScopeWorkgroup]] %[[WorkgroupMemory]] %[[Const42]]
24+
; CHECK: %[[Neg42:[0-9]+]] = OpFNegate %[[TyFP32]] %[[Const42]]
25+
; CHECK: OpAtomicFAddEXT %[[TyFP32]] %[[DblPtr]] %[[ScopeWorkgroup]] %[[WorkgroupMemory]] %[[Neg42]]
2126

2227
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
2328
target triple = "spir64"
@@ -39,5 +44,15 @@ entry:
3944

4045
declare dso_local spir_func float @_Z21__spirv_AtomicFAddEXT(ptr addrspace(1), i32, i32, float)
4146

47+
define dso_local spir_func void @test3() local_unnamed_addr {
48+
entry:
49+
%r1 = tail call spir_func float @_Z25atomic_fetch_add_explicitPU3AS1VU7_Atomicff12memory_order(ptr addrspace(1) @f, float 42.000000e+00, i32 0)
50+
%r2 = tail call spir_func float @_Z25atomic_fetch_sub_explicitPU3AS1VU7_Atomicff12memory_order(ptr addrspace(1) @f, float 42.000000e+00, i32 0)
51+
ret void
52+
}
53+
54+
declare spir_func float @_Z25atomic_fetch_add_explicitPU3AS1VU7_Atomicff12memory_order(ptr addrspace(1), float, i32)
55+
declare spir_func float @_Z25atomic_fetch_sub_explicitPU3AS1VU7_Atomicff12memory_order(ptr addrspace(1), float, i32)
56+
4257
!llvm.module.flags = !{!0}
4358
!0 = !{i32 1, !"wchar_size", i32 4}

0 commit comments

Comments
 (0)