Skip to content

Commit 4baea8b

Browse files
[SPIR-V] Implement insertion of 'Group and Subgroup Instructions' using builtin functions (#95176)
This PR adds builtin functions to insert instructions from the 'Group and Subgroup Instructions' section of the SPIR-V Specification. The corresponding tests are updated, and a `spirv-val` run is added where it was missing.
1 parent 9890f94 commit 4baea8b

File tree

4 files changed

+101
-6
lines changed

4 files changed

+101
-6
lines changed

llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,7 +1015,35 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
10151015
const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
10161016
const SPIRV::GroupBuiltin *GroupBuiltin =
10171017
SPIRV::lookupGroupBuiltin(Builtin->Name);
1018+
10181019
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1020+
if (Call->isSpirvOp()) {
1021+
if (GroupBuiltin->NoGroupOperation)
1022+
return buildOpFromWrapper(MIRBuilder, GroupBuiltin->Opcode, Call,
1023+
GR->getSPIRVTypeID(Call->ReturnType));
1024+
1025+
// Group Operation is a literal
1026+
Register GroupOpReg = Call->Arguments[1];
1027+
const MachineInstr *MI = getDefInstrMaybeConstant(GroupOpReg, MRI);
1028+
if (!MI || MI->getOpcode() != TargetOpcode::G_CONSTANT)
1029+
report_fatal_error(
1030+
"Group Operation parameter must be an integer constant");
1031+
uint64_t GrpOp = MI->getOperand(1).getCImm()->getValue().getZExtValue();
1032+
Register ScopeReg = Call->Arguments[0];
1033+
if (!MRI->getRegClassOrNull(ScopeReg))
1034+
MRI->setRegClass(ScopeReg, &SPIRV::IDRegClass);
1035+
Register ValueReg = Call->Arguments[2];
1036+
if (!MRI->getRegClassOrNull(ValueReg))
1037+
MRI->setRegClass(ValueReg, &SPIRV::IDRegClass);
1038+
MIRBuilder.buildInstr(GroupBuiltin->Opcode)
1039+
.addDef(Call->ReturnRegister)
1040+
.addUse(GR->getSPIRVTypeID(Call->ReturnType))
1041+
.addUse(ScopeReg)
1042+
.addImm(GrpOp)
1043+
.addUse(ValueReg);
1044+
return true;
1045+
}
1046+
10191047
Register Arg0;
10201048
if (GroupBuiltin->HasBoolArg) {
10211049
Register ConstRegister = Call->Arguments[0];

llvm/lib/Target/SPIRV/SPIRVBuiltins.td

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -694,9 +694,17 @@ multiclass DemangledGroupBuiltin<string name, int level /* OnlyWork/OnlySub/...
694694
}
695695
}
696696

697+
multiclass DemangledGroupBuiltinWrapper<string name, bits<8> minNumArgs, bits<8> maxNumArgs, Op operation> {
698+
def : DemangledBuiltin<name, OpenCL_std, Group, minNumArgs, maxNumArgs>;
699+
def : GroupBuiltin<name, operation>;
700+
}
701+
697702
defm : DemangledGroupBuiltin<"group_all", WorkOrSub, OpGroupAll>;
703+
defm : DemangledGroupBuiltinWrapper<"__spirv_GroupAll", 2, 2, OpGroupAll>;
698704
defm : DemangledGroupBuiltin<"group_any", WorkOrSub, OpGroupAny>;
705+
defm : DemangledGroupBuiltinWrapper<"__spirv_GroupAny", 2, 2, OpGroupAny>;
699706
defm : DemangledGroupBuiltin<"group_broadcast", WorkOrSub, OpGroupBroadcast>;
707+
defm : DemangledGroupBuiltinWrapper<"__spirv_GroupBroadcast", 3, 3, OpGroupBroadcast>;
700708
defm : DemangledGroupBuiltin<"group_non_uniform_broadcast", OnlySub, OpGroupNonUniformBroadcast>;
701709
defm : DemangledGroupBuiltin<"group_broadcast_first", OnlySub, OpGroupNonUniformBroadcastFirst>;
702710

@@ -731,41 +739,49 @@ defm : DemangledGroupBuiltin<"group_scan_inclusive_adds", WorkOrSub, OpGroupIAdd
731739
defm : DemangledGroupBuiltin<"group_reduce_addu", WorkOrSub, OpGroupIAdd>;
732740
defm : DemangledGroupBuiltin<"group_scan_exclusive_addu", WorkOrSub, OpGroupIAdd>;
733741
defm : DemangledGroupBuiltin<"group_scan_inclusive_addu", WorkOrSub, OpGroupIAdd>;
742+
defm : DemangledGroupBuiltinWrapper<"__spirv_GroupIAdd", 3, 3, OpGroupIAdd>;
734743

735744
defm : DemangledGroupBuiltin<"group_fadd", WorkOrSub, OpGroupFAdd>;
736745
defm : DemangledGroupBuiltin<"group_reduce_addf", WorkOrSub, OpGroupFAdd>;
737746
defm : DemangledGroupBuiltin<"group_scan_exclusive_addf", WorkOrSub, OpGroupFAdd>;
738747
defm : DemangledGroupBuiltin<"group_scan_inclusive_addf", WorkOrSub, OpGroupFAdd>;
748+
defm : DemangledGroupBuiltinWrapper<"__spirv_GroupFAdd", 3, 3, OpGroupFAdd>;
739749

740750
defm : DemangledGroupBuiltin<"group_fmin", WorkOrSub, OpGroupFMin>;
741751
defm : DemangledGroupBuiltin<"group_reduce_minf", WorkOrSub, OpGroupFMin>;
742752
defm : DemangledGroupBuiltin<"group_scan_exclusive_minf", WorkOrSub, OpGroupFMin>;
743753
defm : DemangledGroupBuiltin<"group_scan_inclusive_minf", WorkOrSub, OpGroupFMin>;
754+
defm : DemangledGroupBuiltinWrapper<"__spirv_GroupFMin", 3, 3, OpGroupFMin>;
744755

745756
defm : DemangledGroupBuiltin<"group_umin", WorkOrSub, OpGroupUMin>;
746757
defm : DemangledGroupBuiltin<"group_reduce_minu", WorkOrSub, OpGroupUMin>;
747758
defm : DemangledGroupBuiltin<"group_scan_exclusive_minu", WorkOrSub, OpGroupUMin>;
748759
defm : DemangledGroupBuiltin<"group_scan_inclusive_minu", WorkOrSub, OpGroupUMin>;
760+
defm : DemangledGroupBuiltinWrapper<"__spirv_GroupUMin", 3, 3, OpGroupUMin>;
749761

750762
defm : DemangledGroupBuiltin<"group_smin", WorkOrSub, OpGroupSMin>;
751763
defm : DemangledGroupBuiltin<"group_reduce_mins", WorkOrSub, OpGroupSMin>;
752764
defm : DemangledGroupBuiltin<"group_scan_exclusive_mins", WorkOrSub, OpGroupSMin>;
753765
defm : DemangledGroupBuiltin<"group_scan_inclusive_mins", WorkOrSub, OpGroupSMin>;
766+
defm : DemangledGroupBuiltinWrapper<"__spirv_GroupSMin", 3, 3, OpGroupSMin>;
754767

755768
defm : DemangledGroupBuiltin<"group_fmax", WorkOrSub, OpGroupFMax>;
756769
defm : DemangledGroupBuiltin<"group_reduce_maxf", WorkOrSub, OpGroupFMax>;
757770
defm : DemangledGroupBuiltin<"group_scan_exclusive_maxf", WorkOrSub, OpGroupFMax>;
758771
defm : DemangledGroupBuiltin<"group_scan_inclusive_maxf", WorkOrSub, OpGroupFMax>;
772+
defm : DemangledGroupBuiltinWrapper<"__spirv_GroupFMax", 3, 3, OpGroupFMax>;
759773

760774
defm : DemangledGroupBuiltin<"group_umax", WorkOrSub, OpGroupUMax>;
761775
defm : DemangledGroupBuiltin<"group_reduce_maxu", WorkOrSub, OpGroupUMax>;
762776
defm : DemangledGroupBuiltin<"group_scan_exclusive_maxu", WorkOrSub, OpGroupUMax>;
763777
defm : DemangledGroupBuiltin<"group_scan_inclusive_maxu", WorkOrSub, OpGroupUMax>;
778+
defm : DemangledGroupBuiltinWrapper<"__spirv_GroupUMax", 3, 3, OpGroupUMax>;
764779

765780
defm : DemangledGroupBuiltin<"group_smax", WorkOrSub, OpGroupSMax>;
766781
defm : DemangledGroupBuiltin<"group_reduce_maxs", WorkOrSub, OpGroupSMax>;
767782
defm : DemangledGroupBuiltin<"group_scan_exclusive_maxs", WorkOrSub, OpGroupSMax>;
768783
defm : DemangledGroupBuiltin<"group_scan_inclusive_maxs", WorkOrSub, OpGroupSMax>;
784+
defm : DemangledGroupBuiltinWrapper<"__spirv_GroupSMax", 3, 3, OpGroupSMax>;
769785

770786
// cl_khr_subgroup_non_uniform_arithmetic
771787
defm : DemangledGroupBuiltin<"group_non_uniform_iadd", WorkOrSub, OpGroupNonUniformIAdd>;
Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,29 @@
1+
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
3+
14
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
5+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
26

37
; CHECK-SPIRV: OpCapability Groups
4-
; CHECK-SPIRV: %[[#BoolTypeID:]] = OpTypeBool
5-
; CHECK-SPIRV: %[[#ConstID:]] = OpConstantTrue %[[#BoolTypeID]]
6-
; CHECK-SPIRV: %[[#]] = OpGroupAll %[[#BoolTypeID]] %[[#]] %[[#ConstID]]
7-
; CHECK-SPIRV: %[[#]] = OpGroupAny %[[#BoolTypeID]] %[[#]] %[[#ConstID]]
8+
; CHECK-SPIRV-DAG: %[[#BoolTypeID:]] = OpTypeBool
9+
; CHECK-SPIRV-DAG: %[[#True:]] = OpConstantTrue %[[#BoolTypeID]]
10+
; CHECK-SPIRV-DAG: %[[#False:]] = OpConstantFalse %[[#BoolTypeID]]
11+
; CHECK-SPIRV: %[[#]] = OpGroupAll %[[#BoolTypeID]] %[[#]] %[[#True]]
12+
; CHECK-SPIRV: %[[#]] = OpGroupAny %[[#BoolTypeID]] %[[#]] %[[#True]]
13+
; CHECK-SPIRV: %[[#]] = OpGroupAll %[[#BoolTypeID]] %[[#]] %[[#True]]
14+
; CHECK-SPIRV: %[[#]] = OpGroupAny %[[#BoolTypeID]] %[[#]] %[[#False]]
815

916
define spir_kernel void @test(i32 addrspace(1)* nocapture readnone %i) {
1017
entry:
1118
%call = tail call spir_func i32 @_Z14work_group_alli(i32 5)
1219
%call1 = tail call spir_func i32 @_Z14work_group_anyi(i32 5)
20+
%call3 = tail call spir_func i32 @__spirv_GroupAll(i32 0, i1 1)
21+
%call4 = tail call spir_func i32 @__spirv_GroupAny(i32 0, i1 0)
1322
ret void
1423
}
1524

1625
declare spir_func i32 @_Z14work_group_alli(i32)
17-
1826
declare spir_func i32 @_Z14work_group_anyi(i32)
27+
28+
declare spir_func i1 @__spirv_GroupAll(i32, i1)
29+
declare spir_func i1 @__spirv_GroupAny(i32, i1)

llvm/test/CodeGen/SPIRV/transcoding/group_ops.ll

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1+
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
3+
14
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
25
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
36

47
; CHECK-SPIRV-DAG: %[[#int:]] = OpTypeInt 32 0
58
; CHECK-SPIRV-DAG: %[[#float:]] = OpTypeFloat 32
9+
; CHECK-SPIRV-DAG: %[[#ScopeCrossWorkgroup:]] = OpConstant %[[#int]] 0
610
; CHECK-SPIRV-DAG: %[[#ScopeWorkgroup:]] = OpConstant %[[#int]] 2
711
; CHECK-SPIRV-DAG: %[[#ScopeSubgroup:]] = OpConstant %[[#int]] 3
812

@@ -247,7 +251,8 @@ entry:
247251
declare spir_func i32 @_Z21work_group_reduce_minj(i32 noundef) local_unnamed_addr
248252

249253
; CHECK-SPIRV: OpFunction
250-
; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#int]] %[[#ScopeWorkgroup]]
254+
; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#int]] %[[#ScopeWorkgroup]] %[[#BroadcastValue:]] %[[#BroadcastLocalId:]]
255+
; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#int]] %[[#ScopeCrossWorkgroup]] %[[#BroadcastValue]] %[[#BroadcastLocalId]]
251256
; CHECK-SPIRV: OpFunctionEnd
252257

253258
;; kernel void testWorkGroupBroadcast(uint a, global size_t *id, global int *res) {
@@ -259,7 +264,42 @@ entry:
259264
%0 = load i32, i32 addrspace(1)* %id, align 4
260265
%call = call spir_func i32 @_Z20work_group_broadcastjj(i32 noundef %a, i32 noundef %0)
261266
store i32 %call, i32 addrspace(1)* %res, align 4
267+
%call1 = call spir_func i32 @__spirv_GroupBroadcast(i32 0, i32 noundef %a, i32 noundef %0)
262268
ret void
263269
}
264270

265271
declare spir_func i32 @_Z20work_group_broadcastjj(i32 noundef, i32 noundef) local_unnamed_addr
272+
declare spir_func i32 @__spirv_GroupBroadcast(i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr
273+
274+
; CHECK-SPIRV: OpFunction
275+
; CHECK-SPIRV: %[[#]] = OpGroupFAdd %[[#float]] %[[#ScopeCrossWorkgroup]] Reduce %[[#FValue:]]
276+
; CHECK-SPIRV: %[[#]] = OpGroupFMin %[[#float]] %[[#ScopeWorkgroup]] InclusiveScan %[[#FValue]]
277+
; CHECK-SPIRV: %[[#]] = OpGroupFMax %[[#float]] %[[#ScopeSubgroup]] ExclusiveScan %[[#FValue]]
278+
; CHECK-SPIRV: %[[#]] = OpGroupIAdd %[[#int]] %[[#ScopeCrossWorkgroup]] Reduce %[[#IValue:]]
279+
; CHECK-SPIRV: %[[#]] = OpGroupUMin %[[#int]] %[[#ScopeWorkgroup]] InclusiveScan %[[#IValue]]
280+
; CHECK-SPIRV: %[[#]] = OpGroupSMin %[[#int]] %[[#ScopeSubgroup]] ExclusiveScan %[[#IValue]]
281+
; CHECK-SPIRV: %[[#]] = OpGroupUMax %[[#int]] %[[#ScopeCrossWorkgroup]] Reduce %[[#IValue]]
282+
; CHECK-SPIRV: %[[#]] = OpGroupSMax %[[#int]] %[[#ScopeWorkgroup]] InclusiveScan %[[#IValue]]
283+
; CHECK-SPIRV: OpFunctionEnd
284+
285+
define spir_kernel void @foo(float %a, i32 %b) {
286+
entry:
287+
%f1 = call spir_func float @__spirv_GroupFAdd(i32 0, i32 0, float %a)
288+
%f2 = call spir_func float @__spirv_GroupFMin(i32 2, i32 1, float %a)
289+
%f3 = call spir_func float @__spirv_GroupFMax(i32 3, i32 2, float %a)
290+
%i1 = call spir_func i32 @__spirv_GroupIAdd(i32 0, i32 0, i32 %b)
291+
%i2 = call spir_func i32 @__spirv_GroupUMin(i32 2, i32 1, i32 %b)
292+
%i3 = call spir_func i32 @__spirv_GroupSMin(i32 3, i32 2, i32 %b)
293+
%i4 = call spir_func i32 @__spirv_GroupUMax(i32 0, i32 0, i32 %b)
294+
%i5 = call spir_func i32 @__spirv_GroupSMax(i32 2, i32 1, i32 %b)
295+
ret void
296+
}
297+
298+
declare spir_func float @__spirv_GroupFAdd(i32, i32, float)
299+
declare spir_func float @__spirv_GroupFMin(i32, i32, float)
300+
declare spir_func float @__spirv_GroupFMax(i32, i32, float)
301+
declare spir_func i32 @__spirv_GroupIAdd(i32, i32, i32)
302+
declare spir_func i32 @__spirv_GroupUMin(i32, i32, i32)
303+
declare spir_func i32 @__spirv_GroupSMin(i32, i32, i32)
304+
declare spir_func i32 @__spirv_GroupUMax(i32, i32, i32)
305+
declare spir_func i32 @__spirv_GroupSMax(i32, i32, i32)

0 commit comments

Comments
 (0)