Skip to content

Commit 0e34766

Browse files
[SPIR-V] Implement support of the SPV_EXT_arithmetic_fence SPIRV extension (#110500)
This PR implements support of the SPV_EXT_arithmetic_fence SPIRV extension: https://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/EXT/SPV_EXT_arithmetic_fence.html.
1 parent c538d5c commit 0e34766

File tree

7 files changed

+88
-0
lines changed

7 files changed

+88
-0
lines changed

llvm/docs/SPIRVUsage.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ list of supported SPIR-V extensions, sorted alphabetically by their extension na
147147
- Adds atomic add instruction on floating-point numbers.
148148
* - ``SPV_EXT_shader_atomic_float_min_max``
149149
- Adds atomic min and max instruction on floating-point numbers.
150+
* - ``SPV_EXT_arithmetic_fence``
151+
- Adds an instruction that prevents fast-math optimizations between its argument and the expression that contains it.
150152
* - ``SPV_INTEL_arbitrary_precision_integers``
151153
- Allows generating arbitrary width integer types.
152154
* - ``SPV_INTEL_bfloat16_conversion``

llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ static const std::map<std::string, SPIRV::Extension::Extension>
2828
SPIRV::Extension::Extension::SPV_EXT_shader_atomic_float16_add},
2929
{"SPV_EXT_shader_atomic_float_min_max",
3030
SPIRV::Extension::Extension::SPV_EXT_shader_atomic_float_min_max},
31+
{"SPV_EXT_arithmetic_fence",
32+
SPIRV::Extension::Extension::SPV_EXT_arithmetic_fence},
3133
{"SPV_INTEL_arbitrary_precision_integers",
3234
SPIRV::Extension::Extension::SPV_INTEL_arbitrary_precision_integers},
3335
{"SPV_INTEL_cache_controls",

llvm/lib/Target/SPIRV/SPIRVInstrInfo.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -878,3 +878,7 @@ def OpCooperativeMatrixMulAddKHR: Op<4459, (outs ID:$res),
878878
"$res = OpCooperativeMatrixMulAddKHR $type $A $B $C">;
879879
def OpCooperativeMatrixLengthKHR: Op<4460, (outs ID:$res), (ins TYPE:$type, ID:$coop_matr_type),
880880
"$res = OpCooperativeMatrixLengthKHR $type $coop_matr_type">;
881+
882+
// SPV_EXT_arithmetic_fence
883+
def OpArithmeticFenceEXT: Op<6145, (outs ID:$res), (ins TYPE:$type, ID:$target),
884+
"$res = OpArithmeticFenceEXT $type $target">;

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2600,6 +2600,16 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
26002600
.addUse(I.getOperand(2).getReg())
26012601
.addUse(I.getOperand(3).getReg());
26022602
break;
2603+
case Intrinsic::arithmetic_fence:
2604+
if (STI.canUseExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence))
2605+
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpArithmeticFenceEXT))
2606+
.addDef(ResVReg)
2607+
.addUse(GR.getSPIRVTypeID(ResType))
2608+
.addUse(I.getOperand(2).getReg());
2609+
else
2610+
BuildMI(BB, I, I.getDebugLoc(), TII.get(TargetOpcode::COPY), ResVReg)
2611+
.addUse(I.getOperand(2).getReg());
2612+
break;
26032613
case Intrinsic::spv_thread_id:
26042614
return selectSpvThreadId(ResVReg, ResType, I);
26052615
case Intrinsic::spv_fdot:

llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1200,6 +1200,14 @@ void addInstrRequirements(const MachineInstr &MI,
12001200
Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
12011201
Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
12021202
break;
1203+
case SPIRV::OpArithmeticFenceEXT:
1204+
if (!ST.canUseExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence))
1205+
report_fatal_error("OpArithmeticFenceEXT requires the "
1206+
"following SPIR-V extension: SPV_EXT_arithmetic_fence",
1207+
false);
1208+
Reqs.addExtension(SPIRV::Extension::SPV_EXT_arithmetic_fence);
1209+
Reqs.addCapability(SPIRV::Capability::ArithmeticFenceEXT);
1210+
break;
12031211
default:
12041212
break;
12051213
}

llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,7 @@ defm SPV_INTEL_cache_controls : ExtensionOperand<108>;
303303
defm SPV_INTEL_global_variable_host_access : ExtensionOperand<109>;
304304
defm SPV_INTEL_global_variable_fpga_decorations : ExtensionOperand<110>;
305305
defm SPV_KHR_cooperative_matrix : ExtensionOperand<111>;
306+
defm SPV_EXT_arithmetic_fence : ExtensionOperand<112>;
306307

307308
//===----------------------------------------------------------------------===//
308309
// Multiclass used to define Capabilities enum values and at the same time
@@ -480,6 +481,7 @@ defm HostAccessINTEL : CapabilityOperand<6188, 0, 0, [SPV_INTEL_global_variable_
480481
defm GlobalVariableFPGADecorationsINTEL : CapabilityOperand<6189, 0, 0, [SPV_INTEL_global_variable_fpga_decorations], []>;
481482
defm CacheControlsINTEL : CapabilityOperand<6441, 0, 0, [SPV_INTEL_cache_controls], []>;
482483
defm CooperativeMatrixKHR : CapabilityOperand<6022, 0, 0, [SPV_KHR_cooperative_matrix], []>;
484+
defm ArithmeticFenceEXT : CapabilityOperand<6144, 0, 0, [SPV_EXT_arithmetic_fence], []>;
483485

484486
//===----------------------------------------------------------------------===//
485487
// Multiclass used to define SourceLanguage enum values and at the same time
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-linux %s -o - | FileCheck %s --check-prefixes=CHECK-NOEXT
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-linux %s -o - --spirv-ext=+SPV_EXT_arithmetic_fence | FileCheck %s --check-prefixes=CHECK-EXT
5+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
6+
7+
; Checks for the run without --spirv-ext=+SPV_EXT_arithmetic_fence.
; "-NOT" is FileCheck's negative-match suffix; each -NOT pattern must be absent
; from the output between the neighbouring positive matches (or the start/end
; of the output when there is no neighbour on that side).
; CHECK-NOEXT-NOT: OpCapability ArithmeticFenceEXT
8+
; CHECK-NOEXT-NOT: OpExtension "SPV_EXT_arithmetic_fence"
9+
; CHECK-NOEXT: OpFunction
10+
; CHECK-NOEXT: OpFMul
11+
; CHECK-NOEXT: OpFAdd
12+
; CHECK-NOEXT-NOT: OpArithmeticFenceEXT
13+
; CHECK-NOEXT: OpFunction
14+
; CHECK-NOEXT-NOT: OpArithmeticFenceEXT
15+
; CHECK-NOEXT: OpFunction
16+
; CHECK-NOEXT-NOT: OpArithmeticFenceEXT
17+
18+
; CHECK-EXT: OpCapability ArithmeticFenceEXT
19+
; CHECK-EXT: OpExtension "SPV_EXT_arithmetic_fence"
20+
; CHECK-EXT: OpFunction
21+
; CHECK-EXT: [[R1:%.*]] = OpFMul [[I32Ty:%.*]] %[[#]] %[[#]]
22+
; CHECK-EXT: [[R2:%.*]] = OpArithmeticFenceEXT [[I32Ty]] [[R1]]
23+
; CHECK-EXT: %[[#]] = OpFAdd [[I32Ty]] [[R2]] %[[#]]
24+
; CHECK-EXT: OpFunction
25+
; CHECK-EXT: [[R3:%.*]] = OpFAdd [[I64Ty:%.*]] [[A1:%.*]] [[A1]]
26+
; CHECK-EXT: [[R4:%.*]] = OpArithmeticFenceEXT [[I64Ty]] [[R3]]
27+
; CHECK-EXT: [[R5:%.*]] = OpFAdd [[I64Ty]] [[A1]] [[A1]]
28+
; CHECK-EXT: %[[#]] = OpFAdd [[I64Ty]] [[R4]] [[R5]]
29+
; CHECK-EXT: OpFunction
30+
; CHECK-EXT: [[R6:%.*]] = OpFAdd [[I32VecTy:%.*]] [[A2:%.*]] [[A2]]
31+
; CHECK-EXT: [[R7:%.*]] = OpArithmeticFenceEXT [[I32VecTy]] [[R6]]
32+
; CHECK-EXT: [[R8:%.*]] = OpFAdd [[I32VecTy]] [[A2]] [[A2]]
33+
; CHECK-EXT: %[[#]] = OpFAdd [[I32VecTy]] [[R7]] [[R8]]
34+
35+
; f1: scalar float case. The fast-math product of %b and %a is passed through
; llvm.arithmetic.fence, and the fenced value is then added to %c.
define float @f1(float %a, float %b, float %c) {
36+
%mul = fmul fast float %b, %a
37+
%tmp = call float @llvm.arithmetic.fence.f32(float %mul)
38+
%add = fadd fast float %tmp, %c
39+
ret float %add
40+
}
41+
42+
; f2: scalar double case. %a + %a is computed twice with the fast flag; only
; the first sum goes through llvm.arithmetic.fence before the two results are
; added together.
define double @f2(double %a) {
43+
%1 = fadd fast double %a, %a
44+
%t = call double @llvm.arithmetic.fence.f64(double %1)
45+
%2 = fadd fast double %a, %a
46+
%3 = fadd fast double %t, %2
47+
ret double %3
48+
}
49+
50+
; f3: vector case. Same shape as the double test above, but on <2 x float>:
; the first of two identical fast sums is fenced via
; llvm.arithmetic.fence.v2f32 and then added to the second, unfenced sum.
define <2 x float> @f3(<2 x float> %a) {
51+
%1 = fadd fast <2 x float> %a, %a
52+
%t = call <2 x float> @llvm.arithmetic.fence.v2f32(<2 x float> %1)
53+
%2 = fadd fast <2 x float> %a, %a
54+
%3 = fadd fast <2 x float> %t, %2
55+
ret <2 x float> %3
56+
}
57+
58+
declare float @llvm.arithmetic.fence.f32(float)
59+
declare double @llvm.arithmetic.fence.f64(double)
60+
declare <2 x float> @llvm.arithmetic.fence.v2f32(<2 x float>)

0 commit comments

Comments
 (0)