Skip to content

Commit 23c72e9

Browse files
[SPIR-V] Allow non-const arguments in a Group builtin that requires a boolean argument (#102902)
This PR resolves a TODO in `generateGroupInst()` (`lib/Target/SPIRV/SPIRVBuiltins.cpp`) and Issues #97311 and #97312 by implementing support for non-const arguments in a Group builtin that requires a boolean argument.
1 parent 3825a7c commit 23c72e9

File tree

2 files changed

+93
-13
lines changed

2 files changed

+93
-13
lines changed

llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,16 +1091,30 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
10911091

10921092
Register Arg0;
10931093
if (GroupBuiltin->HasBoolArg) {
1094-
Register ConstRegister = Call->Arguments[0];
1095-
auto ArgInstruction = getDefInstrMaybeConstant(ConstRegister, MRI);
1096-
(void)ArgInstruction;
1097-
// TODO: support non-constant bool values.
1098-
assert(ArgInstruction->getOpcode() == TargetOpcode::G_CONSTANT &&
1099-
"Only constant bool value args are supported");
1100-
if (GR->getSPIRVTypeForVReg(Call->Arguments[0])->getOpcode() !=
1101-
SPIRV::OpTypeBool)
1102-
Arg0 = GR->buildConstantInt(getIConstVal(ConstRegister, MRI), MIRBuilder,
1103-
GR->getOrCreateSPIRVBoolType(MIRBuilder));
1094+
SPIRVType *BoolType = GR->getOrCreateSPIRVBoolType(MIRBuilder);
1095+
Register BoolReg = Call->Arguments[0];
1096+
SPIRVType *BoolRegType = GR->getSPIRVTypeForVReg(BoolReg);
1097+
if (!BoolRegType)
1098+
report_fatal_error("Can't find a register's type definition");
1099+
MachineInstr *ArgInstruction = getDefInstrMaybeConstant(BoolReg, MRI);
1100+
if (ArgInstruction->getOpcode() == TargetOpcode::G_CONSTANT) {
1101+
if (BoolRegType->getOpcode() != SPIRV::OpTypeBool)
1102+
Arg0 = GR->buildConstantInt(getIConstVal(BoolReg, MRI), MIRBuilder,
1103+
BoolType);
1104+
} else {
1105+
if (BoolRegType->getOpcode() == SPIRV::OpTypeInt) {
1106+
Arg0 = MRI->createGenericVirtualRegister(LLT::scalar(1));
1107+
MRI->setRegClass(Arg0, &SPIRV::IDRegClass);
1108+
GR->assignSPIRVTypeToVReg(BoolType, Arg0, MIRBuilder.getMF());
1109+
MIRBuilder.buildICmp(CmpInst::ICMP_NE, Arg0, BoolReg,
1110+
GR->buildConstantInt(0, MIRBuilder, BoolRegType));
1111+
insertAssignInstr(Arg0, nullptr, BoolType, GR, MIRBuilder,
1112+
MIRBuilder.getMF().getRegInfo());
1113+
} else if (BoolRegType->getOpcode() != SPIRV::OpTypeBool) {
1114+
report_fatal_error("Expect a boolean argument");
1115+
}
1116+
// if BoolReg is a boolean register, we don't need to do anything
1117+
}
11041118
}
11051119

11061120
Register GroupResultRegister = Call->ReturnRegister;

llvm/test/CodeGen/SPIRV/transcoding/OpGroupAllAny.ll

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88
; CHECK-SPIRV-DAG: %[[#BoolTypeID:]] = OpTypeBool
99
; CHECK-SPIRV-DAG: %[[#True:]] = OpConstantTrue %[[#BoolTypeID]]
1010
; CHECK-SPIRV-DAG: %[[#False:]] = OpConstantFalse %[[#BoolTypeID]]
11+
12+
; CHECK-SPIRV: OpFunction
1113
; CHECK-SPIRV: %[[#]] = OpGroupAll %[[#BoolTypeID]] %[[#]] %[[#True]]
1214
; CHECK-SPIRV: %[[#]] = OpGroupAny %[[#BoolTypeID]] %[[#]] %[[#True]]
1315
; CHECK-SPIRV: %[[#]] = OpGroupAll %[[#BoolTypeID]] %[[#]] %[[#True]]
1416
; CHECK-SPIRV: %[[#]] = OpGroupAny %[[#BoolTypeID]] %[[#]] %[[#False]]
15-
1617
define spir_kernel void @test(i32 addrspace(1)* nocapture readnone %i) {
1718
entry:
1819
%call = tail call spir_func i32 @_Z14work_group_alli(i32 5)
@@ -22,8 +23,73 @@ entry:
2223
ret void
2324
}
2425

25-
declare spir_func i32 @_Z14work_group_alli(i32)
26-
declare spir_func i32 @_Z14work_group_anyi(i32)
26+
; CHECK-SPIRV: OpFunction
27+
; CHECK-SPIRV: %[[#]] = OpGroupAny %[[#BoolTypeID]] %[[#]] %[[#]]
28+
; CHECK-SPIRV: %[[#]] = OpGroupAll %[[#BoolTypeID]] %[[#]] %[[#]]
29+
define spir_kernel void @test_nonconst_any(ptr addrspace(1) %input, ptr addrspace(1) %output) #0 !kernel_arg_addr_space !7 !kernel_arg_access_qual !8 !kernel_arg_type !9 !kernel_arg_type_qual !10 !kernel_arg_base_type !9 !spirv.ParameterDecorations !11 {
30+
entry:
31+
%r0 = call spir_func i64 @_Z13get_global_idj(i32 0)
32+
%r1 = insertelement <3 x i64> undef, i64 %r0, i32 0
33+
%r2 = call spir_func i64 @_Z13get_global_idj(i32 1)
34+
%r3 = insertelement <3 x i64> %r1, i64 %r2, i32 1
35+
%r4 = call spir_func i64 @_Z13get_global_idj(i32 2)
36+
%r5 = insertelement <3 x i64> %r3, i64 %r4, i32 2
37+
%call = extractelement <3 x i64> %r5, i32 0
38+
%conv = trunc i64 %call to i32
39+
%idxprom = sext i32 %conv to i64
40+
%arrayidx = getelementptr inbounds float, ptr addrspace(1) %input, i64 %idxprom
41+
%r6 = load float, ptr addrspace(1) %arrayidx, align 4
42+
%add = add nsw i32 %conv, 1
43+
%idxprom1 = sext i32 %add to i64
44+
%arrayidx2 = getelementptr inbounds float, ptr addrspace(1) %input, i64 %idxprom1
45+
%r7 = load float, ptr addrspace(1) %arrayidx2, align 4
46+
%cmp = fcmp ogt float %r6, %r7
47+
%conv3 = select i1 %cmp, i32 1, i32 0
48+
%r8 = icmp ne i32 %conv3, 0
49+
%r9 = zext i1 %r8 to i32
50+
%r10 = call spir_func i32 @_Z14work_group_anyi(i32 %r9)
51+
%call41 = icmp ne i32 %r10, 0
52+
%call4 = select i1 %call41, i32 1, i32 0
53+
%idxprom5 = sext i32 %conv to i64
54+
%arrayidx6 = getelementptr inbounds i32, ptr addrspace(1) %output, i64 %idxprom5
55+
store i32 %call4, ptr addrspace(1) %arrayidx6, align 4
56+
%r11 = call spir_func i32 @_Z14work_group_alli(i32 %r9)
57+
%call42 = icmp ne i32 %r11, 0
58+
%call5 = select i1 %call42, i32 1, i32 0
59+
store i32 %call5, ptr addrspace(1) %arrayidx6, align 4
60+
ret void
61+
}
62+
63+
declare spir_func i64 @_Z13get_global_idj(i32) #1
64+
65+
declare spir_func i32 @_Z14work_group_alli(i32) #2
66+
declare spir_func i32 @_Z14work_group_anyi(i32) #2
2767

2868
declare spir_func i1 @__spirv_GroupAll(i32, i1)
2969
declare spir_func i1 @__spirv_GroupAny(i32, i1)
70+
71+
attributes #0 = { nounwind }
72+
attributes #1 = { nounwind willreturn memory(none) }
73+
attributes #2 = { convergent nounwind }
74+
75+
!spirv.MemoryModel = !{!0}
76+
!opencl.enable.FP_CONTRACT = !{}
77+
!spirv.Source = !{!1}
78+
!opencl.spir.version = !{!2}
79+
!opencl.ocl.version = !{!3}
80+
!opencl.used.extensions = !{!4}
81+
!opencl.used.optional.core.features = !{!5}
82+
!spirv.Generator = !{!6}
83+
84+
!0 = !{i32 2, i32 2}
85+
!1 = !{i32 3, i32 300000}
86+
!2 = !{i32 2, i32 0}
87+
!3 = !{i32 3, i32 0}
88+
!4 = !{!"cl_khr_subgroups"}
89+
!5 = !{}
90+
!6 = !{i16 6, i16 14}
91+
!7 = !{i32 1, i32 1}
92+
!8 = !{!"none", !"none"}
93+
!9 = !{!"float*", !"int*"}
94+
!10 = !{!"", !""}
95+
!11 = !{!5, !5}

0 commit comments

Comments
 (0)