Skip to content

Commit 2fc7a72

Browse files
[SPIR-V] Add implementation of the non-const G_BUILD_VECTOR and fix emission of the OpGroupBroadcast instruction (#103050)
This PR addresses a TODO in lib/Target/SPIRV/SPIRVInstructionSelector.cpp by adding an implementation of the non-const G_BUILD_VECTOR, and fixes emission of the OpGroupBroadcast instruction for the case when the `..._group_broadcast` builtin has more than one `local_id` argument: `OpGroupBroadcast` then requires a newly constructed vector with 2 or 3 components instead of the originally passed series of `local_id` arguments. This PR may resolve #97310 if the reported failure is caused by an incorrectly generated OpGroupBroadcast instruction, which was definitely occurring. An existing test is hardened and a new test is added to cover this special case of OpGroupBroadcast instruction emission.
1 parent c4206f1 commit 2fc7a72

File tree

5 files changed

+248
-41
lines changed

5 files changed

+248
-41
lines changed

llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1135,6 +1135,35 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
11351135
: SPIRV::Scope::Workgroup;
11361136
Register ScopeRegister = buildConstantIntReg(Scope, MIRBuilder, GR);
11371137

1138+
Register VecReg;
1139+
if (GroupBuiltin->Opcode == SPIRV::OpGroupBroadcast &&
1140+
Call->Arguments.size() > 2) {
1141+
// For OpGroupBroadcast "LocalId must be an integer datatype. It must be a
1142+
// scalar, a vector with 2 components, or a vector with 3 components.",
1143+
// meaning that we must create a vector from the function arguments if
1144+
// it's a work_group_broadcast(val, local_id_x, local_id_y) or
1145+
// work_group_broadcast(val, local_id_x, local_id_y, local_id_z) call.
1146+
Register ElemReg = Call->Arguments[1];
1147+
SPIRVType *ElemType = GR->getSPIRVTypeForVReg(ElemReg);
1148+
if (!ElemType || ElemType->getOpcode() != SPIRV::OpTypeInt)
1149+
report_fatal_error("Expect an integer <LocalId> argument");
1150+
unsigned VecLen = Call->Arguments.size() - 1;
1151+
VecReg = MRI->createGenericVirtualRegister(
1152+
LLT::fixed_vector(VecLen, MRI->getType(ElemReg)));
1153+
MRI->setRegClass(VecReg, &SPIRV::vIDRegClass);
1154+
SPIRVType *VecType =
1155+
GR->getOrCreateSPIRVVectorType(ElemType, VecLen, MIRBuilder);
1156+
GR->assignSPIRVTypeToVReg(VecType, VecReg, MIRBuilder.getMF());
1157+
auto MIB =
1158+
MIRBuilder.buildInstr(TargetOpcode::G_BUILD_VECTOR).addDef(VecReg);
1159+
for (unsigned i = 1; i < Call->Arguments.size(); i++) {
1160+
MIB.addUse(Call->Arguments[i]);
1161+
MRI->setRegClass(Call->Arguments[i], &SPIRV::iIDRegClass);
1162+
}
1163+
insertAssignInstr(VecReg, nullptr, VecType, GR, MIRBuilder,
1164+
MIRBuilder.getMF().getRegInfo());
1165+
}
1166+
11381167
// Build work/sub group instruction.
11391168
auto MIB = MIRBuilder.buildInstr(GroupBuiltin->Opcode)
11401169
.addDef(GroupResultRegister)
@@ -1146,10 +1175,13 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
11461175
if (Call->Arguments.size() > 0) {
11471176
MIB.addUse(Arg0.isValid() ? Arg0 : Call->Arguments[0]);
11481177
MRI->setRegClass(Call->Arguments[0], &SPIRV::iIDRegClass);
1149-
for (unsigned i = 1; i < Call->Arguments.size(); i++) {
1150-
MIB.addUse(Call->Arguments[i]);
1151-
MRI->setRegClass(Call->Arguments[i], &SPIRV::iIDRegClass);
1152-
}
1178+
if (VecReg.isValid())
1179+
MIB.addUse(VecReg);
1180+
else
1181+
for (unsigned i = 1; i < Call->Arguments.size(); i++) {
1182+
MIB.addUse(Call->Arguments[i]);
1183+
MRI->setRegClass(Call->Arguments[i], &SPIRV::iIDRegClass);
1184+
}
11531185
}
11541186

11551187
// Build select instruction.

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 36 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ class SPIRVInstructionSelector : public InstructionSelector {
159159
bool selectBitreverse(Register ResVReg, const SPIRVType *ResType,
160160
MachineInstr &I) const;
161161

162-
bool selectConstVector(Register ResVReg, const SPIRVType *ResType,
162+
bool selectBuildVector(Register ResVReg, const SPIRVType *ResType,
163163
MachineInstr &I) const;
164164
bool selectSplatVector(Register ResVReg, const SPIRVType *ResType,
165165
MachineInstr &I) const;
@@ -411,7 +411,7 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
411411
return selectBitreverse(ResVReg, ResType, I);
412412

413413
case TargetOpcode::G_BUILD_VECTOR:
414-
return selectConstVector(ResVReg, ResType, I);
414+
return selectBuildVector(ResVReg, ResType, I);
415415
case TargetOpcode::G_SPLAT_VECTOR:
416416
return selectSplatVector(ResVReg, ResType, I);
417417

@@ -1497,35 +1497,6 @@ bool SPIRVInstructionSelector::selectFreeze(Register ResVReg,
14971497
return false;
14981498
}
14991499

1500-
bool SPIRVInstructionSelector::selectConstVector(Register ResVReg,
1501-
const SPIRVType *ResType,
1502-
MachineInstr &I) const {
1503-
// TODO: only const case is supported for now.
1504-
assert(std::all_of(
1505-
I.operands_begin(), I.operands_end(), [this](const MachineOperand &MO) {
1506-
if (MO.isDef())
1507-
return true;
1508-
if (!MO.isReg())
1509-
return false;
1510-
SPIRVType *ConstTy = this->MRI->getVRegDef(MO.getReg());
1511-
assert(ConstTy && ConstTy->getOpcode() == SPIRV::ASSIGN_TYPE &&
1512-
ConstTy->getOperand(1).isReg());
1513-
Register ConstReg = ConstTy->getOperand(1).getReg();
1514-
const MachineInstr *Const = this->MRI->getVRegDef(ConstReg);
1515-
assert(Const);
1516-
return (Const->getOpcode() == TargetOpcode::G_CONSTANT ||
1517-
Const->getOpcode() == TargetOpcode::G_FCONSTANT);
1518-
}));
1519-
1520-
auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
1521-
TII.get(SPIRV::OpConstantComposite))
1522-
.addDef(ResVReg)
1523-
.addUse(GR.getSPIRVTypeID(ResType));
1524-
for (unsigned i = I.getNumExplicitDefs(); i < I.getNumExplicitOperands(); ++i)
1525-
MIB.addUse(I.getOperand(i).getReg());
1526-
return MIB.constrainAllUses(TII, TRI, RBI);
1527-
}
1528-
15291500
static unsigned getArrayComponentCount(MachineRegisterInfo *MRI,
15301501
const SPIRVType *ResType) {
15311502
Register OpReg = ResType->getOperand(2).getReg();
@@ -1591,6 +1562,40 @@ static bool isConstReg(MachineRegisterInfo *MRI, Register OpReg) {
15911562
return false;
15921563
}
15931564

1565+
bool SPIRVInstructionSelector::selectBuildVector(Register ResVReg,
1566+
const SPIRVType *ResType,
1567+
MachineInstr &I) const {
1568+
unsigned N = 0;
1569+
if (ResType->getOpcode() == SPIRV::OpTypeVector)
1570+
N = GR.getScalarOrVectorComponentCount(ResType);
1571+
else if (ResType->getOpcode() == SPIRV::OpTypeArray)
1572+
N = getArrayComponentCount(MRI, ResType);
1573+
else
1574+
report_fatal_error("Cannot select G_BUILD_VECTOR with a non-vector result");
1575+
if (I.getNumExplicitOperands() - I.getNumExplicitDefs() != N)
1576+
report_fatal_error("G_BUILD_VECTOR and the result type are inconsistent");
1577+
1578+
// check if we may construct a constant vector
1579+
bool IsConst = true;
1580+
for (unsigned i = I.getNumExplicitDefs();
1581+
i < I.getNumExplicitOperands() && IsConst; ++i)
1582+
if (!isConstReg(MRI, I.getOperand(i).getReg()))
1583+
IsConst = false;
1584+
1585+
if (!IsConst && N < 2)
1586+
report_fatal_error(
1587+
"There must be at least two constituent operands in a vector");
1588+
1589+
auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
1590+
TII.get(IsConst ? SPIRV::OpConstantComposite
1591+
: SPIRV::OpCompositeConstruct))
1592+
.addDef(ResVReg)
1593+
.addUse(GR.getSPIRVTypeID(ResType));
1594+
for (unsigned i = I.getNumExplicitDefs(); i < I.getNumExplicitOperands(); ++i)
1595+
MIB.addUse(I.getOperand(i).getReg());
1596+
return MIB.constrainAllUses(TII, TRI, RBI);
1597+
}
1598+
15941599
bool SPIRVInstructionSelector::selectSplatVector(Register ResVReg,
15951600
const SPIRVType *ResType,
15961601
MachineInstr &I) const {

llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -530,15 +530,23 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
530530
MachineInstr *ElemMI = MRI.getVRegDef(MI.getOperand(1).getReg());
531531
assert(ElemMI);
532532

533-
if (ElemMI->getOpcode() == TargetOpcode::G_CONSTANT)
533+
if (ElemMI->getOpcode() == TargetOpcode::G_CONSTANT) {
534534
ElemTy = ElemMI->getOperand(1).getCImm()->getType();
535-
else if (ElemMI->getOpcode() == TargetOpcode::G_FCONSTANT)
535+
} else if (ElemMI->getOpcode() == TargetOpcode::G_FCONSTANT) {
536536
ElemTy = ElemMI->getOperand(1).getFPImm()->getType();
537+
} else {
538+
// There may be a case when we already know Reg's type.
539+
MachineInstr *NextMI = MI.getNextNode();
540+
if (!NextMI || NextMI->getOpcode() != SPIRV::ASSIGN_TYPE ||
541+
NextMI->getOperand(1).getReg() != Reg)
542+
llvm_unreachable("Unexpected opcode");
543+
}
544+
if (ElemTy)
545+
Ty = VectorType::get(
546+
ElemTy, MI.getNumExplicitOperands() - MI.getNumExplicitDefs(),
547+
false);
537548
else
538-
llvm_unreachable("Unexpected opcode");
539-
unsigned NumElts =
540-
MI.getNumExplicitOperands() - MI.getNumExplicitDefs();
541-
Ty = VectorType::get(ElemTy, NumElts, false);
549+
NeedAssignType = false;
542550
}
543551
if (NeedAssignType)
544552
insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MRI);
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
5+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
6+
7+
; CHECK-SPIRV: OpCapability Groups
8+
; CHECK-SPIRV-DAG: %[[#Int32Ty:]] = OpTypeInt 32 0
9+
; CHECK-SPIRV-DAG: %[[#Int64Ty:]] = OpTypeInt 64 0
10+
; CHECK-SPIRV-DAG: %[[#Float32Ty:]] = OpTypeFloat 32
11+
; CHECK-SPIRV-DAG: %[[#Vec2Int32Ty:]] = OpTypeVector %[[#Int32Ty]] 2
12+
; CHECK-SPIRV-DAG: %[[#Vec3Int32Ty:]] = OpTypeVector %[[#Int32Ty]] 3
13+
; CHECK-SPIRV-DAG: %[[#Vec2Int64Ty:]] = OpTypeVector %[[#Int64Ty]] 2
14+
; CHECK-SPIRV-DAG: %[[#C2:]] = OpConstant %[[#Int32Ty]] 2
15+
16+
; CHECK-SPIRV: OpFunction
17+
; CHECK-SPIRV: %[[#Val:]] = OpFunctionParameter %[[#Int32Ty]]
18+
; CHECK-SPIRV: %[[#X:]] = OpFunctionParameter %[[#Int32Ty]]
19+
; CHECK-SPIRV: %[[#Y:]] = OpFunctionParameter %[[#Int32Ty]]
20+
; CHECK-SPIRV: %[[#Z:]] = OpFunctionParameter %[[#Int32Ty]]
21+
; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#Int32Ty]] %[[#C2]] %[[#Val]] %[[#X]]
22+
; CHECK-SPIRV: %[[#XY:]] = OpCompositeConstruct %[[#Vec2Int32Ty]] %[[#X]] %[[#Y]]
23+
; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#Int32Ty]] %[[#C2]] %[[#Val]] %[[#XY]]
24+
; CHECK-SPIRV: %[[#XYZ:]] = OpCompositeConstruct %[[#Vec3Int32Ty]] %[[#X]] %[[#Y]] %[[#Z]]
25+
; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#Int32Ty]] %[[#C2]] %[[#Val]] %[[#XYZ]]
26+
define spir_kernel void @test_broadcast_xyz(i32 noundef %a, i32 noundef %x, i32 noundef %y, i32 noundef %z) {
27+
entry:
28+
%call1 = call spir_func i32 @_Z20work_group_broadcastjj(i32 noundef %a, i32 noundef %x)
29+
%call2 = call spir_func i32 @_Z20work_group_broadcastjj(i32 noundef %a, i32 noundef %x, i32 noundef %y)
30+
%call3 = call spir_func i32 @_Z20work_group_broadcastjj(i32 noundef %a, i32 noundef %x, i32 noundef %y, i32 noundef %z)
31+
ret void
32+
}
33+
34+
declare spir_func i32 @_Z20work_group_broadcastjj(i32, i32)
35+
declare spir_func i32 @_Z20work_group_broadcastjjj(i32, i32, i32)
36+
declare spir_func i32 @_Z20work_group_broadcastjjjj(i32, i32, i32, i32)
37+
38+
; CHECK-SPIRV: OpFunction
39+
; CHECK-SPIRV: OpInBoundsPtrAccessChain
40+
; CHECK-SPIRV: %[[#LoadedVal:]] = OpLoad %[[#Float32Ty]] %[[#]]
41+
; CHECK-SPIRV: %[[#IdX:]] = OpCompositeExtract %[[#Int64Ty]] %[[#]] 0
42+
; CHECK-SPIRV: %[[#IdY:]] = OpCompositeExtract %[[#Int64Ty]] %[[#]] 1
43+
; CHECK-SPIRV: %[[#LocIdsVec:]] = OpCompositeConstruct %[[#Vec2Int64Ty]] %[[#IdX]] %[[#IdY]]
44+
; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#Float32Ty]] %[[#C2]] %[[#LoadedVal]] %[[#LocIdsVec]]
45+
define spir_kernel void @test_wg_broadcast_2D(ptr addrspace(1) %input, ptr addrspace(1) %output) #0 !kernel_arg_addr_space !7 !kernel_arg_access_qual !8 !kernel_arg_type !9 !kernel_arg_type_qual !10 !kernel_arg_base_type !9 !spirv.ParameterDecorations !11 {
46+
entry:
47+
%0 = call spir_func i64 @_Z13get_global_idj(i32 0) #1
48+
%1 = insertelement <3 x i64> undef, i64 %0, i32 0
49+
%2 = call spir_func i64 @_Z13get_global_idj(i32 1) #1
50+
%3 = insertelement <3 x i64> %1, i64 %2, i32 1
51+
%4 = call spir_func i64 @_Z13get_global_idj(i32 2) #1
52+
%5 = insertelement <3 x i64> %3, i64 %4, i32 2
53+
%call = extractelement <3 x i64> %5, i32 0
54+
%6 = call spir_func i64 @_Z13get_global_idj(i32 0) #1
55+
%7 = insertelement <3 x i64> undef, i64 %6, i32 0
56+
%8 = call spir_func i64 @_Z13get_global_idj(i32 1) #1
57+
%9 = insertelement <3 x i64> %7, i64 %8, i32 1
58+
%10 = call spir_func i64 @_Z13get_global_idj(i32 2) #1
59+
%11 = insertelement <3 x i64> %9, i64 %10, i32 2
60+
%call1 = extractelement <3 x i64> %11, i32 1
61+
%12 = call spir_func i64 @_Z12get_group_idj(i32 0) #1
62+
%13 = insertelement <3 x i64> undef, i64 %12, i32 0
63+
%14 = call spir_func i64 @_Z12get_group_idj(i32 1) #1
64+
%15 = insertelement <3 x i64> %13, i64 %14, i32 1
65+
%16 = call spir_func i64 @_Z12get_group_idj(i32 2) #1
66+
%17 = insertelement <3 x i64> %15, i64 %16, i32 2
67+
%call2 = extractelement <3 x i64> %17, i32 0
68+
%18 = call spir_func i64 @_Z14get_local_sizej(i32 0) #1
69+
%19 = insertelement <3 x i64> undef, i64 %18, i32 0
70+
%20 = call spir_func i64 @_Z14get_local_sizej(i32 1) #1
71+
%21 = insertelement <3 x i64> %19, i64 %20, i32 1
72+
%22 = call spir_func i64 @_Z14get_local_sizej(i32 2) #1
73+
%23 = insertelement <3 x i64> %21, i64 %22, i32 2
74+
%call3 = extractelement <3 x i64> %23, i32 0
75+
%rem = urem i64 %call2, %call3
76+
%24 = call spir_func i64 @_Z12get_group_idj(i32 0) #1
77+
%25 = insertelement <3 x i64> undef, i64 %24, i32 0
78+
%26 = call spir_func i64 @_Z12get_group_idj(i32 1) #1
79+
%27 = insertelement <3 x i64> %25, i64 %26, i32 1
80+
%28 = call spir_func i64 @_Z12get_group_idj(i32 2) #1
81+
%29 = insertelement <3 x i64> %27, i64 %28, i32 2
82+
%call4 = extractelement <3 x i64> %29, i32 1
83+
%30 = call spir_func i64 @_Z14get_local_sizej(i32 0) #1
84+
%31 = insertelement <3 x i64> undef, i64 %30, i32 0
85+
%32 = call spir_func i64 @_Z14get_local_sizej(i32 1) #1
86+
%33 = insertelement <3 x i64> %31, i64 %32, i32 1
87+
%34 = call spir_func i64 @_Z14get_local_sizej(i32 2) #1
88+
%35 = insertelement <3 x i64> %33, i64 %34, i32 2
89+
%call5 = extractelement <3 x i64> %35, i32 1
90+
%rem6 = urem i64 %call4, %call5
91+
%36 = call spir_func i64 @_Z15get_global_sizej(i32 0) #1
92+
%37 = insertelement <3 x i64> undef, i64 %36, i32 0
93+
%38 = call spir_func i64 @_Z15get_global_sizej(i32 1) #1
94+
%39 = insertelement <3 x i64> %37, i64 %38, i32 1
95+
%40 = call spir_func i64 @_Z15get_global_sizej(i32 2) #1
96+
%41 = insertelement <3 x i64> %39, i64 %40, i32 2
97+
%call7 = extractelement <3 x i64> %41, i32 0
98+
%mul = mul i64 %call1, %call7
99+
%add = add i64 %mul, %call
100+
%arrayidx = getelementptr inbounds float, ptr addrspace(1) %input, i64 %add
101+
%42 = load float, ptr addrspace(1) %arrayidx, align 4
102+
%.splatinsert = insertelement <2 x i64> undef, i64 %rem, i32 0
103+
%.splat = shufflevector <2 x i64> %.splatinsert, <2 x i64> undef, <2 x i32> zeroinitializer
104+
%43 = insertelement <2 x i64> %.splat, i64 %rem6, i32 1
105+
%44 = extractelement <2 x i64> %43, i32 0
106+
%45 = extractelement <2 x i64> %43, i32 1
107+
%call8 = call spir_func float @_Z20work_group_broadcastfmm(float %42, i64 %44, i64 %45) #2
108+
%arrayidx9 = getelementptr inbounds float, ptr addrspace(1) %output, i64 %add
109+
store float %call8, ptr addrspace(1) %arrayidx9, align 4
110+
ret void
111+
}
112+
113+
; Function Attrs: nounwind willreturn memory(none)
114+
declare spir_func i64 @_Z13get_global_idj(i32) #1
115+
116+
; Function Attrs: nounwind willreturn memory(none)
117+
declare spir_func i64 @_Z12get_group_idj(i32) #1
118+
119+
; Function Attrs: nounwind willreturn memory(none)
120+
declare spir_func i64 @_Z14get_local_sizej(i32) #1
121+
122+
; Function Attrs: nounwind willreturn memory(none)
123+
declare spir_func i64 @_Z15get_global_sizej(i32) #1
124+
125+
; Function Attrs: convergent nounwind
126+
declare spir_func float @_Z20work_group_broadcastfmm(float, i64, i64) #2
127+
128+
attributes #0 = { nounwind }
129+
attributes #1 = { nounwind willreturn memory(none) }
130+
attributes #2 = { convergent nounwind }
131+
132+
!spirv.MemoryModel = !{!0}
133+
!opencl.enable.FP_CONTRACT = !{}
134+
!spirv.Source = !{!1}
135+
!opencl.spir.version = !{!2}
136+
!opencl.ocl.version = !{!3}
137+
!opencl.used.extensions = !{!4}
138+
!opencl.used.optional.core.features = !{!5}
139+
!spirv.Generator = !{!6}
140+
141+
!0 = !{i32 2, i32 2}
142+
!1 = !{i32 3, i32 300000}
143+
!2 = !{i32 2, i32 0}
144+
!3 = !{i32 3, i32 0}
145+
!4 = !{!"cl_khr_subgroups"}
146+
!5 = !{}
147+
!6 = !{i16 6, i16 14}
148+
!7 = !{i32 1, i32 1}
149+
!8 = !{!"none", !"none"}
150+
!9 = !{!"float*", !"float*"}
151+
!10 = !{!"", !""}
152+
!11 = !{!5, !5}

llvm/test/CodeGen/SPIRV/transcoding/group_ops.ll

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
66

77
; CHECK-SPIRV-DAG: %[[#int:]] = OpTypeInt 32 0
8+
; CHECK-SPIRV-DAG: %[[#intv2:]] = OpTypeVector %[[#int]] 2
9+
; CHECK-SPIRV-DAG: %[[#intv3:]] = OpTypeVector %[[#int]] 3
810
; CHECK-SPIRV-DAG: %[[#float:]] = OpTypeFloat 32
911
; CHECK-SPIRV-DAG: %[[#ScopeCrossWorkgroup:]] = OpConstant %[[#int]] 0
1012
; CHECK-SPIRV-DAG: %[[#ScopeWorkgroup:]] = OpConstant %[[#int]] 2
@@ -252,6 +254,10 @@ declare spir_func i32 @_Z21work_group_reduce_minj(i32 noundef) local_unnamed_add
252254

253255
; CHECK-SPIRV: OpFunction
254256
; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#int]] %[[#ScopeWorkgroup]] %[[#BroadcastValue:]] %[[#BroadcastLocalId:]]
257+
; CHECK-SPIRV: %[[#BroadcastVec2:]] = OpCompositeConstruct %[[#intv2]] %[[#BroadcastLocalId]] %[[#BroadcastLocalId]]
258+
; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#int]] %[[#ScopeWorkgroup]] %[[#BroadcastValue]] %[[#BroadcastVec2]]
259+
; CHECK-SPIRV: %[[#BroadcastVec3:]] = OpCompositeConstruct %[[#intv3]] %[[#BroadcastLocalId]] %[[#BroadcastLocalId]] %[[#BroadcastLocalId]]
260+
; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#int]] %[[#ScopeWorkgroup]] %[[#BroadcastValue]] %[[#BroadcastVec3]]
255261
; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#int]] %[[#ScopeCrossWorkgroup]] %[[#BroadcastValue]] %[[#BroadcastLocalId]]
256262
; CHECK-SPIRV: OpFunctionEnd
257263

@@ -263,12 +269,16 @@ define dso_local spir_kernel void @testWorkGroupBroadcast(i32 noundef %a, i32 ad
263269
entry:
264270
%0 = load i32, i32 addrspace(1)* %id, align 4
265271
%call = call spir_func i32 @_Z20work_group_broadcastjj(i32 noundef %a, i32 noundef %0)
272+
%call_v2 = call spir_func i32 @_Z20work_group_broadcastjj(i32 noundef %a, i32 noundef %0, i32 noundef %0)
273+
%call_v3 = call spir_func i32 @_Z20work_group_broadcastjj(i32 noundef %a, i32 noundef %0, i32 noundef %0, i32 noundef %0)
266274
store i32 %call, i32 addrspace(1)* %res, align 4
267275
%call1 = call spir_func i32 @__spirv_GroupBroadcast(i32 0, i32 noundef %a, i32 noundef %0)
268276
ret void
269277
}
270278

271279
declare spir_func i32 @_Z20work_group_broadcastjj(i32 noundef, i32 noundef) local_unnamed_addr
280+
declare spir_func i32 @_Z20work_group_broadcastjjj(i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr
281+
declare spir_func i32 @_Z20work_group_broadcastjjjj(i32 noundef, i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr
272282
declare spir_func i32 @__spirv_GroupBroadcast(i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr
273283

274284
; CHECK-SPIRV: OpFunction

0 commit comments

Comments
 (0)