-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[SPIR-V] Add implementation of the non-const G_BUILD_VECTOR and fix emission of the OpGroupBroadcast instruction #103050
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPIR-V] Add implementation of the non-const G_BUILD_VECTOR and fix emission of the OpGroupBroadcast instruction #103050
Conversation
…f the OpGroupBroadcast instruction
@llvm/pr-subscribers-backend-spir-v Author: Vyacheslav Levytskyy (VyacheslavLevytskyy) Changes: This PR addresses a TODO in lib/Target/SPIRV/SPIRVInstructionSelector.cpp by adding an implementation of the non-const G_BUILD_VECTOR, and fixes emission of the OpGroupBroadcast instruction for the case when the `..._group_broadcast` builtin has more than one `local_id` argument and `OpGroupBroadcast` requires a newly constructed vector with 2 or 3 components instead of the originally passed series of `local_id` arguments. This PR may resolve #97310 if the reported failure is caused by an incorrectly generated OpGroupBroadcast instruction, which was definitely happening. The existing test is hardened and a new test is added to cover this special case of the OpGroupBroadcast instruction emission. Full diff: https://github.com/llvm/llvm-project/pull/103050.diff 5 Files Affected:
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 8fa5106cef32e9..09f06728d2d10d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -1135,6 +1135,35 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
: SPIRV::Scope::Workgroup;
Register ScopeRegister = buildConstantIntReg(Scope, MIRBuilder, GR);
+ Register VecReg;
+ if (GroupBuiltin->Opcode == SPIRV::OpGroupBroadcast &&
+ Call->Arguments.size() > 2) {
+ // For OpGroupBroadcast "LocalId must be an integer datatype. It must be a
+ // scalar, a vector with 2 components, or a vector with 3 components.",
+ // meaning that we must create a vector from the function arguments if
+ // it's a work_group_broadcast(val, local_id_x, local_id_y) or
+ // work_group_broadcast(val, local_id_x, local_id_y, local_id_z) call.
+ Register ElemReg = Call->Arguments[1];
+ SPIRVType *ElemType = GR->getSPIRVTypeForVReg(ElemReg);
+ if (!ElemType || ElemType->getOpcode() != SPIRV::OpTypeInt)
+ report_fatal_error("Expect an integer <LocalId> argument");
+ unsigned VecLen = Call->Arguments.size() - 1;
+ VecReg = MRI->createGenericVirtualRegister(
+ LLT::fixed_vector(VecLen, MRI->getType(ElemReg)));
+ MRI->setRegClass(VecReg, &SPIRV::vIDRegClass);
+ SPIRVType *VecType =
+ GR->getOrCreateSPIRVVectorType(ElemType, VecLen, MIRBuilder);
+ GR->assignSPIRVTypeToVReg(VecType, VecReg, MIRBuilder.getMF());
+ auto MIB =
+ MIRBuilder.buildInstr(TargetOpcode::G_BUILD_VECTOR).addDef(VecReg);
+ for (unsigned i = 1; i < Call->Arguments.size(); i++) {
+ MIB.addUse(Call->Arguments[i]);
+ MRI->setRegClass(Call->Arguments[i], &SPIRV::iIDRegClass);
+ }
+ insertAssignInstr(VecReg, nullptr, VecType, GR, MIRBuilder,
+ MIRBuilder.getMF().getRegInfo());
+ }
+
// Build work/sub group instruction.
auto MIB = MIRBuilder.buildInstr(GroupBuiltin->Opcode)
.addDef(GroupResultRegister)
@@ -1146,10 +1175,13 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
if (Call->Arguments.size() > 0) {
MIB.addUse(Arg0.isValid() ? Arg0 : Call->Arguments[0]);
MRI->setRegClass(Call->Arguments[0], &SPIRV::iIDRegClass);
- for (unsigned i = 1; i < Call->Arguments.size(); i++) {
- MIB.addUse(Call->Arguments[i]);
- MRI->setRegClass(Call->Arguments[i], &SPIRV::iIDRegClass);
- }
+ if (VecReg.isValid())
+ MIB.addUse(VecReg);
+ else
+ for (unsigned i = 1; i < Call->Arguments.size(); i++) {
+ MIB.addUse(Call->Arguments[i]);
+ MRI->setRegClass(Call->Arguments[i], &SPIRV::iIDRegClass);
+ }
}
// Build select instruction.
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index c55235a04a607f..7681108e800a43 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -159,7 +159,7 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectBitreverse(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
- bool selectConstVector(Register ResVReg, const SPIRVType *ResType,
+ bool selectBuildVector(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
bool selectSplatVector(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
@@ -405,7 +405,7 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
return selectBitreverse(ResVReg, ResType, I);
case TargetOpcode::G_BUILD_VECTOR:
- return selectConstVector(ResVReg, ResType, I);
+ return selectBuildVector(ResVReg, ResType, I);
case TargetOpcode::G_SPLAT_VECTOR:
return selectSplatVector(ResVReg, ResType, I);
@@ -1457,35 +1457,6 @@ bool SPIRVInstructionSelector::selectFreeze(Register ResVReg,
return false;
}
-bool SPIRVInstructionSelector::selectConstVector(Register ResVReg,
- const SPIRVType *ResType,
- MachineInstr &I) const {
- // TODO: only const case is supported for now.
- assert(std::all_of(
- I.operands_begin(), I.operands_end(), [this](const MachineOperand &MO) {
- if (MO.isDef())
- return true;
- if (!MO.isReg())
- return false;
- SPIRVType *ConstTy = this->MRI->getVRegDef(MO.getReg());
- assert(ConstTy && ConstTy->getOpcode() == SPIRV::ASSIGN_TYPE &&
- ConstTy->getOperand(1).isReg());
- Register ConstReg = ConstTy->getOperand(1).getReg();
- const MachineInstr *Const = this->MRI->getVRegDef(ConstReg);
- assert(Const);
- return (Const->getOpcode() == TargetOpcode::G_CONSTANT ||
- Const->getOpcode() == TargetOpcode::G_FCONSTANT);
- }));
-
- auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
- TII.get(SPIRV::OpConstantComposite))
- .addDef(ResVReg)
- .addUse(GR.getSPIRVTypeID(ResType));
- for (unsigned i = I.getNumExplicitDefs(); i < I.getNumExplicitOperands(); ++i)
- MIB.addUse(I.getOperand(i).getReg());
- return MIB.constrainAllUses(TII, TRI, RBI);
-}
-
static unsigned getArrayComponentCount(MachineRegisterInfo *MRI,
const SPIRVType *ResType) {
Register OpReg = ResType->getOperand(2).getReg();
@@ -1551,6 +1522,40 @@ static bool isConstReg(MachineRegisterInfo *MRI, Register OpReg) {
return false;
}
+bool SPIRVInstructionSelector::selectBuildVector(Register ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I) const {
+ unsigned N = 0;
+ if (ResType->getOpcode() == SPIRV::OpTypeVector)
+ N = GR.getScalarOrVectorComponentCount(ResType);
+ else if (ResType->getOpcode() == SPIRV::OpTypeArray)
+ N = getArrayComponentCount(MRI, ResType);
+ else
+ report_fatal_error("Cannot select G_BUILD_VECTOR with a non-vector result");
+ if (I.getNumExplicitOperands() - I.getNumExplicitDefs() != N)
+ report_fatal_error("G_BUILD_VECTOR and the result type are inconsistent");
+
+ // check if we may construct a constant vector
+ bool IsConst = true;
+ for (unsigned i = I.getNumExplicitDefs();
+ i < I.getNumExplicitOperands() && IsConst; ++i)
+ if (!isConstReg(MRI, I.getOperand(i).getReg()))
+ IsConst = false;
+
+ if (!IsConst && N < 2)
+ report_fatal_error(
+ "There must be at least two constituent operands in a vector");
+
+ auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
+ TII.get(IsConst ? SPIRV::OpConstantComposite
+ : SPIRV::OpCompositeConstruct))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType));
+ for (unsigned i = I.getNumExplicitDefs(); i < I.getNumExplicitOperands(); ++i)
+ MIB.addUse(I.getOperand(i).getReg());
+ return MIB.constrainAllUses(TII, TRI, RBI);
+}
+
bool SPIRVInstructionSelector::selectSplatVector(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 6838f4bf9410f0..7c158abba3c28c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -530,15 +530,23 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
MachineInstr *ElemMI = MRI.getVRegDef(MI.getOperand(1).getReg());
assert(ElemMI);
- if (ElemMI->getOpcode() == TargetOpcode::G_CONSTANT)
+ if (ElemMI->getOpcode() == TargetOpcode::G_CONSTANT) {
ElemTy = ElemMI->getOperand(1).getCImm()->getType();
- else if (ElemMI->getOpcode() == TargetOpcode::G_FCONSTANT)
+ } else if (ElemMI->getOpcode() == TargetOpcode::G_FCONSTANT) {
ElemTy = ElemMI->getOperand(1).getFPImm()->getType();
+ } else {
+ // There may be a case when we already know Reg's type.
+ MachineInstr *NextMI = MI.getNextNode();
+ if (!NextMI || NextMI->getOpcode() != SPIRV::ASSIGN_TYPE ||
+ NextMI->getOperand(1).getReg() != Reg)
+ llvm_unreachable("Unexpected opcode");
+ }
+ if (ElemTy)
+ Ty = VectorType::get(
+ ElemTy, MI.getNumExplicitOperands() - MI.getNumExplicitDefs(),
+ false);
else
- llvm_unreachable("Unexpected opcode");
- unsigned NumElts =
- MI.getNumExplicitOperands() - MI.getNumExplicitDefs();
- Ty = VectorType::get(ElemTy, NumElts, false);
+ NeedAssignType = false;
}
if (NeedAssignType)
insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MRI);
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/OpGroupBroadcast.ll b/llvm/test/CodeGen/SPIRV/transcoding/OpGroupBroadcast.ll
new file mode 100644
index 00000000000000..fd737e85a8ec50
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/transcoding/OpGroupBroadcast.ll
@@ -0,0 +1,152 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV: OpCapability Groups
+; CHECK-SPIRV-DAG: %[[#Int32Ty:]] = OpTypeInt 32 0
+; CHECK-SPIRV-DAG: %[[#Int64Ty:]] = OpTypeInt 64 0
+; CHECK-SPIRV-DAG: %[[#Float32Ty:]] = OpTypeFloat 32
+; CHECK-SPIRV-DAG: %[[#Vec2Int32Ty:]] = OpTypeVector %[[#Int32Ty]] 2
+; CHECK-SPIRV-DAG: %[[#Vec3Int32Ty:]] = OpTypeVector %[[#Int32Ty]] 3
+; CHECK-SPIRV-DAG: %[[#Vec2Int64Ty:]] = OpTypeVector %[[#Int64Ty]] 2
+; CHECK-SPIRV-DAG: %[[#C2:]] = OpConstant %[[#Int32Ty]] 2
+
+; CHECK-SPIRV: OpFunction
+; CHECK-SPIRV: %[[#Val:]] = OpFunctionParameter %[[#Int32Ty]]
+; CHECK-SPIRV: %[[#X:]] = OpFunctionParameter %[[#Int32Ty]]
+; CHECK-SPIRV: %[[#Y:]] = OpFunctionParameter %[[#Int32Ty]]
+; CHECK-SPIRV: %[[#Z:]] = OpFunctionParameter %[[#Int32Ty]]
+; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#Int32Ty]] %[[#C2]] %[[#Val]] %[[#X]]
+; CHECK-SPIRV: %[[#XY:]] = OpCompositeConstruct %[[#Vec2Int32Ty]] %[[#X]] %[[#Y]]
+; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#Int32Ty]] %[[#C2]] %[[#Val]] %[[#XY]]
+; CHECK-SPIRV: %[[#XYZ:]] = OpCompositeConstruct %[[#Vec3Int32Ty]] %[[#X]] %[[#Y]] %[[#Z]]
+; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#Int32Ty]] %[[#C2]] %[[#Val]] %[[#XYZ]]
+define spir_kernel void @test_broadcast_xyz(i32 noundef %a, i32 noundef %x, i32 noundef %y, i32 noundef %z) {
+entry:
+ %call1 = call spir_func i32 @_Z20work_group_broadcastjj(i32 noundef %a, i32 noundef %x)
+ %call2 = call spir_func i32 @_Z20work_group_broadcastjj(i32 noundef %a, i32 noundef %x, i32 noundef %y)
+ %call3 = call spir_func i32 @_Z20work_group_broadcastjj(i32 noundef %a, i32 noundef %x, i32 noundef %y, i32 noundef %z)
+ ret void
+}
+
+declare spir_func i32 @_Z20work_group_broadcastjj(i32, i32)
+declare spir_func i32 @_Z20work_group_broadcastjjj(i32, i32, i32)
+declare spir_func i32 @_Z20work_group_broadcastjjjj(i32, i32, i32, i32)
+
+; CHECK-SPIRV: OpFunction
+; CHECK-SPIRV: OpInBoundsPtrAccessChain
+; CHECK-SPIRV: %[[#LoadedVal:]] = OpLoad %[[#Float32Ty]] %[[#]]
+; CHECK-SPIRV: %[[#IdX:]] = OpCompositeExtract %[[#Int64Ty]] %[[#]] 0
+; CHECK-SPIRV: %[[#IdY:]] = OpCompositeExtract %[[#Int64Ty]] %[[#]] 1
+; CHECK-SPIRV: %[[#LocIdsVec:]] = OpCompositeConstruct %[[#Vec2Int64Ty]] %[[#IdX]] %[[#IdY]]
+; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#Float32Ty]] %[[#C2]] %[[#LoadedVal]] %[[#LocIdsVec]]
+define spir_kernel void @test_wg_broadcast_2D(ptr addrspace(1) %input, ptr addrspace(1) %output) #0 !kernel_arg_addr_space !7 !kernel_arg_access_qual !8 !kernel_arg_type !9 !kernel_arg_type_qual !10 !kernel_arg_base_type !9 !spirv.ParameterDecorations !11 {
+entry:
+ %0 = call spir_func i64 @_Z13get_global_idj(i32 0) #1
+ %1 = insertelement <3 x i64> undef, i64 %0, i32 0
+ %2 = call spir_func i64 @_Z13get_global_idj(i32 1) #1
+ %3 = insertelement <3 x i64> %1, i64 %2, i32 1
+ %4 = call spir_func i64 @_Z13get_global_idj(i32 2) #1
+ %5 = insertelement <3 x i64> %3, i64 %4, i32 2
+ %call = extractelement <3 x i64> %5, i32 0
+ %6 = call spir_func i64 @_Z13get_global_idj(i32 0) #1
+ %7 = insertelement <3 x i64> undef, i64 %6, i32 0
+ %8 = call spir_func i64 @_Z13get_global_idj(i32 1) #1
+ %9 = insertelement <3 x i64> %7, i64 %8, i32 1
+ %10 = call spir_func i64 @_Z13get_global_idj(i32 2) #1
+ %11 = insertelement <3 x i64> %9, i64 %10, i32 2
+ %call1 = extractelement <3 x i64> %11, i32 1
+ %12 = call spir_func i64 @_Z12get_group_idj(i32 0) #1
+ %13 = insertelement <3 x i64> undef, i64 %12, i32 0
+ %14 = call spir_func i64 @_Z12get_group_idj(i32 1) #1
+ %15 = insertelement <3 x i64> %13, i64 %14, i32 1
+ %16 = call spir_func i64 @_Z12get_group_idj(i32 2) #1
+ %17 = insertelement <3 x i64> %15, i64 %16, i32 2
+ %call2 = extractelement <3 x i64> %17, i32 0
+ %18 = call spir_func i64 @_Z14get_local_sizej(i32 0) #1
+ %19 = insertelement <3 x i64> undef, i64 %18, i32 0
+ %20 = call spir_func i64 @_Z14get_local_sizej(i32 1) #1
+ %21 = insertelement <3 x i64> %19, i64 %20, i32 1
+ %22 = call spir_func i64 @_Z14get_local_sizej(i32 2) #1
+ %23 = insertelement <3 x i64> %21, i64 %22, i32 2
+ %call3 = extractelement <3 x i64> %23, i32 0
+ %rem = urem i64 %call2, %call3
+ %24 = call spir_func i64 @_Z12get_group_idj(i32 0) #1
+ %25 = insertelement <3 x i64> undef, i64 %24, i32 0
+ %26 = call spir_func i64 @_Z12get_group_idj(i32 1) #1
+ %27 = insertelement <3 x i64> %25, i64 %26, i32 1
+ %28 = call spir_func i64 @_Z12get_group_idj(i32 2) #1
+ %29 = insertelement <3 x i64> %27, i64 %28, i32 2
+ %call4 = extractelement <3 x i64> %29, i32 1
+ %30 = call spir_func i64 @_Z14get_local_sizej(i32 0) #1
+ %31 = insertelement <3 x i64> undef, i64 %30, i32 0
+ %32 = call spir_func i64 @_Z14get_local_sizej(i32 1) #1
+ %33 = insertelement <3 x i64> %31, i64 %32, i32 1
+ %34 = call spir_func i64 @_Z14get_local_sizej(i32 2) #1
+ %35 = insertelement <3 x i64> %33, i64 %34, i32 2
+ %call5 = extractelement <3 x i64> %35, i32 1
+ %rem6 = urem i64 %call4, %call5
+ %36 = call spir_func i64 @_Z15get_global_sizej(i32 0) #1
+ %37 = insertelement <3 x i64> undef, i64 %36, i32 0
+ %38 = call spir_func i64 @_Z15get_global_sizej(i32 1) #1
+ %39 = insertelement <3 x i64> %37, i64 %38, i32 1
+ %40 = call spir_func i64 @_Z15get_global_sizej(i32 2) #1
+ %41 = insertelement <3 x i64> %39, i64 %40, i32 2
+ %call7 = extractelement <3 x i64> %41, i32 0
+ %mul = mul i64 %call1, %call7
+ %add = add i64 %mul, %call
+ %arrayidx = getelementptr inbounds float, ptr addrspace(1) %input, i64 %add
+ %42 = load float, ptr addrspace(1) %arrayidx, align 4
+ %.splatinsert = insertelement <2 x i64> undef, i64 %rem, i32 0
+ %.splat = shufflevector <2 x i64> %.splatinsert, <2 x i64> undef, <2 x i32> zeroinitializer
+ %43 = insertelement <2 x i64> %.splat, i64 %rem6, i32 1
+ %44 = extractelement <2 x i64> %43, i32 0
+ %45 = extractelement <2 x i64> %43, i32 1
+ %call8 = call spir_func float @_Z20work_group_broadcastfmm(float %42, i64 %44, i64 %45) #2
+ %arrayidx9 = getelementptr inbounds float, ptr addrspace(1) %output, i64 %add
+ store float %call8, ptr addrspace(1) %arrayidx9, align 4
+ ret void
+}
+
+; Function Attrs: nounwind willreturn memory(none)
+declare spir_func i64 @_Z13get_global_idj(i32) #1
+
+; Function Attrs: nounwind willreturn memory(none)
+declare spir_func i64 @_Z12get_group_idj(i32) #1
+
+; Function Attrs: nounwind willreturn memory(none)
+declare spir_func i64 @_Z14get_local_sizej(i32) #1
+
+; Function Attrs: nounwind willreturn memory(none)
+declare spir_func i64 @_Z15get_global_sizej(i32) #1
+
+; Function Attrs: convergent nounwind
+declare spir_func float @_Z20work_group_broadcastfmm(float, i64, i64) #2
+
+attributes #0 = { nounwind }
+attributes #1 = { nounwind willreturn memory(none) }
+attributes #2 = { convergent nounwind }
+
+!spirv.MemoryModel = !{!0}
+!opencl.enable.FP_CONTRACT = !{}
+!spirv.Source = !{!1}
+!opencl.spir.version = !{!2}
+!opencl.ocl.version = !{!3}
+!opencl.used.extensions = !{!4}
+!opencl.used.optional.core.features = !{!5}
+!spirv.Generator = !{!6}
+
+!0 = !{i32 2, i32 2}
+!1 = !{i32 3, i32 300000}
+!2 = !{i32 2, i32 0}
+!3 = !{i32 3, i32 0}
+!4 = !{!"cl_khr_subgroups"}
+!5 = !{}
+!6 = !{i16 6, i16 14}
+!7 = !{i32 1, i32 1}
+!8 = !{!"none", !"none"}
+!9 = !{!"float*", !"float*"}
+!10 = !{!"", !""}
+!11 = !{!5, !5}
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/group_ops.ll b/llvm/test/CodeGen/SPIRV/transcoding/group_ops.ll
index 65795d23657c24..870d9b76d5f3eb 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/group_ops.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/group_ops.ll
@@ -5,6 +5,8 @@
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
; CHECK-SPIRV-DAG: %[[#int:]] = OpTypeInt 32 0
+; CHECK-SPIRV-DAG: %[[#intv2:]] = OpTypeVector %[[#int]] 2
+; CHECK-SPIRV-DAG: %[[#intv3:]] = OpTypeVector %[[#int]] 3
; CHECK-SPIRV-DAG: %[[#float:]] = OpTypeFloat 32
; CHECK-SPIRV-DAG: %[[#ScopeCrossWorkgroup:]] = OpConstant %[[#int]] 0
; CHECK-SPIRV-DAG: %[[#ScopeWorkgroup:]] = OpConstant %[[#int]] 2
@@ -252,6 +254,10 @@ declare spir_func i32 @_Z21work_group_reduce_minj(i32 noundef) local_unnamed_add
; CHECK-SPIRV: OpFunction
; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#int]] %[[#ScopeWorkgroup]] %[[#BroadcastValue:]] %[[#BroadcastLocalId:]]
+; CHECK-SPIRV: %[[#BroadcastVec2:]] = OpCompositeConstruct %[[#intv2]] %[[#BroadcastLocalId]] %[[#BroadcastLocalId]]
+; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#int]] %[[#ScopeWorkgroup]] %[[#BroadcastValue]] %[[#BroadcastVec2]]
+; CHECK-SPIRV: %[[#BroadcastVec3:]] = OpCompositeConstruct %[[#intv3]] %[[#BroadcastLocalId]] %[[#BroadcastLocalId]] %[[#BroadcastLocalId]]
+; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#int]] %[[#ScopeWorkgroup]] %[[#BroadcastValue]] %[[#BroadcastVec3]]
; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#int]] %[[#ScopeCrossWorkgroup]] %[[#BroadcastValue]] %[[#BroadcastLocalId]]
; CHECK-SPIRV: OpFunctionEnd
@@ -263,12 +269,16 @@ define dso_local spir_kernel void @testWorkGroupBroadcast(i32 noundef %a, i32 ad
entry:
%0 = load i32, i32 addrspace(1)* %id, align 4
%call = call spir_func i32 @_Z20work_group_broadcastjj(i32 noundef %a, i32 noundef %0)
+ %call_v2 = call spir_func i32 @_Z20work_group_broadcastjj(i32 noundef %a, i32 noundef %0, i32 noundef %0)
+ %call_v3 = call spir_func i32 @_Z20work_group_broadcastjj(i32 noundef %a, i32 noundef %0, i32 noundef %0, i32 noundef %0)
store i32 %call, i32 addrspace(1)* %res, align 4
%call1 = call spir_func i32 @__spirv_GroupBroadcast(i32 0, i32 noundef %a, i32 noundef %0)
ret void
}
declare spir_func i32 @_Z20work_group_broadcastjj(i32 noundef, i32 noundef) local_unnamed_addr
+declare spir_func i32 @_Z20work_group_broadcastjjj(i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr
+declare spir_func i32 @_Z20work_group_broadcastjjjj(i32 noundef, i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr
declare spir_func i32 @__spirv_GroupBroadcast(i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr
; CHECK-SPIRV: OpFunction
|
Thank you! Confirmed this pull request solves the issue in #97310!
|
This PR addresses a TODO in lib/Target/SPIRV/SPIRVInstructionSelector.cpp by adding an implementation of the non-const G_BUILD_VECTOR, and fixes emission of the OpGroupBroadcast instruction for the case when the `..._group_broadcast` builtin has more than one `local_id` argument and `OpGroupBroadcast` requires a newly constructed vector with 2 or 3 components instead of the originally passed series of `local_id` arguments.
This PR may resolve #97310 if the reported failure is caused by an incorrectly generated OpGroupBroadcast instruction, which was definitely happening.
The existing test is hardened and a new test is added to cover this special case of the OpGroupBroadcast instruction emission.