Skip to content

[SPIR-V] Add implementation of the non-const G_BUILD_VECTOR and fix emission of the OpGroupBroadcast instruction #103050

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

VyacheslavLevytskyy
Copy link
Contributor

This PR addresses a TODO in lib/Target/SPIRV/SPIRVInstructionSelector.cpp by adding implementation of the non-const G_BUILD_VECTOR, and fix emission of the OpGroupBroadcast instruction for the case when the ..._group_broadcast builtin has more than one local_id argument and OpGroupBroadcast requires a newly constructed vector with 2 or 3 components instead of originally passed series of local_id arguments.

This PR may resolve #97310 if the reason for the reported fail is an incorrectly generated OpGroupBroadcast instruction that was definitely a case.

Existing test is hardened and a new test is added to cover this special case of the OpGroupBroadcast instruction emission.

@llvmbot
Copy link
Member

llvmbot commented Aug 13, 2024

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

Changes

This PR addresses a TODO in lib/Target/SPIRV/SPIRVInstructionSelector.cpp by adding implementation of the non-const G_BUILD_VECTOR, and fix emission of the OpGroupBroadcast instruction for the case when the ..._group_broadcast builtin has more than one local_id argument and OpGroupBroadcast requires a newly constructed vector with 2 or 3 components instead of originally passed series of local_id arguments.

This PR may resolve #97310 if the reason for the reported fail is an incorrectly generated OpGroupBroadcast instruction that was definitely a case.

Existing test is hardened and a new test is added to cover this special case of the OpGroupBroadcast instruction emission.


Full diff: https://github.com/llvm/llvm-project/pull/103050.diff

5 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (+36-4)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+36-31)
  • (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+14-6)
  • (added) llvm/test/CodeGen/SPIRV/transcoding/OpGroupBroadcast.ll (+152)
  • (modified) llvm/test/CodeGen/SPIRV/transcoding/group_ops.ll (+10)
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 8fa5106cef32e9..09f06728d2d10d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -1135,6 +1135,35 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
                                                       : SPIRV::Scope::Workgroup;
   Register ScopeRegister = buildConstantIntReg(Scope, MIRBuilder, GR);
 
+  Register VecReg;
+  if (GroupBuiltin->Opcode == SPIRV::OpGroupBroadcast &&
+      Call->Arguments.size() > 2) {
+    // For OpGroupBroadcast "LocalId must be an integer datatype. It must be a
+    // scalar, a vector with 2 components, or a vector with 3 components.",
+    // meaning that we must create a vector from the function arguments if
+    // it's a work_group_broadcast(val, local_id_x, local_id_y) or
+    // work_group_broadcast(val, local_id_x, local_id_y, local_id_z) call.
+    Register ElemReg = Call->Arguments[1];
+    SPIRVType *ElemType = GR->getSPIRVTypeForVReg(ElemReg);
+    if (!ElemType || ElemType->getOpcode() != SPIRV::OpTypeInt)
+      report_fatal_error("Expect an integer <LocalId> argument");
+    unsigned VecLen = Call->Arguments.size() - 1;
+    VecReg = MRI->createGenericVirtualRegister(
+        LLT::fixed_vector(VecLen, MRI->getType(ElemReg)));
+    MRI->setRegClass(VecReg, &SPIRV::vIDRegClass);
+    SPIRVType *VecType =
+        GR->getOrCreateSPIRVVectorType(ElemType, VecLen, MIRBuilder);
+    GR->assignSPIRVTypeToVReg(VecType, VecReg, MIRBuilder.getMF());
+    auto MIB =
+        MIRBuilder.buildInstr(TargetOpcode::G_BUILD_VECTOR).addDef(VecReg);
+    for (unsigned i = 1; i < Call->Arguments.size(); i++) {
+      MIB.addUse(Call->Arguments[i]);
+      MRI->setRegClass(Call->Arguments[i], &SPIRV::iIDRegClass);
+    }
+    insertAssignInstr(VecReg, nullptr, VecType, GR, MIRBuilder,
+                      MIRBuilder.getMF().getRegInfo());
+  }
+
   // Build work/sub group instruction.
   auto MIB = MIRBuilder.buildInstr(GroupBuiltin->Opcode)
                  .addDef(GroupResultRegister)
@@ -1146,10 +1175,13 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
   if (Call->Arguments.size() > 0) {
     MIB.addUse(Arg0.isValid() ? Arg0 : Call->Arguments[0]);
     MRI->setRegClass(Call->Arguments[0], &SPIRV::iIDRegClass);
-    for (unsigned i = 1; i < Call->Arguments.size(); i++) {
-      MIB.addUse(Call->Arguments[i]);
-      MRI->setRegClass(Call->Arguments[i], &SPIRV::iIDRegClass);
-    }
+    if (VecReg.isValid())
+      MIB.addUse(VecReg);
+    else
+      for (unsigned i = 1; i < Call->Arguments.size(); i++) {
+        MIB.addUse(Call->Arguments[i]);
+        MRI->setRegClass(Call->Arguments[i], &SPIRV::iIDRegClass);
+      }
   }
 
   // Build select instruction.
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index c55235a04a607f..7681108e800a43 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -159,7 +159,7 @@ class SPIRVInstructionSelector : public InstructionSelector {
   bool selectBitreverse(Register ResVReg, const SPIRVType *ResType,
                         MachineInstr &I) const;
 
-  bool selectConstVector(Register ResVReg, const SPIRVType *ResType,
+  bool selectBuildVector(Register ResVReg, const SPIRVType *ResType,
                          MachineInstr &I) const;
   bool selectSplatVector(Register ResVReg, const SPIRVType *ResType,
                          MachineInstr &I) const;
@@ -405,7 +405,7 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
     return selectBitreverse(ResVReg, ResType, I);
 
   case TargetOpcode::G_BUILD_VECTOR:
-    return selectConstVector(ResVReg, ResType, I);
+    return selectBuildVector(ResVReg, ResType, I);
   case TargetOpcode::G_SPLAT_VECTOR:
     return selectSplatVector(ResVReg, ResType, I);
 
@@ -1457,35 +1457,6 @@ bool SPIRVInstructionSelector::selectFreeze(Register ResVReg,
   return false;
 }
 
-bool SPIRVInstructionSelector::selectConstVector(Register ResVReg,
-                                                 const SPIRVType *ResType,
-                                                 MachineInstr &I) const {
-  // TODO: only const case is supported for now.
-  assert(std::all_of(
-      I.operands_begin(), I.operands_end(), [this](const MachineOperand &MO) {
-        if (MO.isDef())
-          return true;
-        if (!MO.isReg())
-          return false;
-        SPIRVType *ConstTy = this->MRI->getVRegDef(MO.getReg());
-        assert(ConstTy && ConstTy->getOpcode() == SPIRV::ASSIGN_TYPE &&
-               ConstTy->getOperand(1).isReg());
-        Register ConstReg = ConstTy->getOperand(1).getReg();
-        const MachineInstr *Const = this->MRI->getVRegDef(ConstReg);
-        assert(Const);
-        return (Const->getOpcode() == TargetOpcode::G_CONSTANT ||
-                Const->getOpcode() == TargetOpcode::G_FCONSTANT);
-      }));
-
-  auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
-                     TII.get(SPIRV::OpConstantComposite))
-                 .addDef(ResVReg)
-                 .addUse(GR.getSPIRVTypeID(ResType));
-  for (unsigned i = I.getNumExplicitDefs(); i < I.getNumExplicitOperands(); ++i)
-    MIB.addUse(I.getOperand(i).getReg());
-  return MIB.constrainAllUses(TII, TRI, RBI);
-}
-
 static unsigned getArrayComponentCount(MachineRegisterInfo *MRI,
                                        const SPIRVType *ResType) {
   Register OpReg = ResType->getOperand(2).getReg();
@@ -1551,6 +1522,40 @@ static bool isConstReg(MachineRegisterInfo *MRI, Register OpReg) {
   return false;
 }
 
+bool SPIRVInstructionSelector::selectBuildVector(Register ResVReg,
+                                                 const SPIRVType *ResType,
+                                                 MachineInstr &I) const {
+  unsigned N = 0;
+  if (ResType->getOpcode() == SPIRV::OpTypeVector)
+    N = GR.getScalarOrVectorComponentCount(ResType);
+  else if (ResType->getOpcode() == SPIRV::OpTypeArray)
+    N = getArrayComponentCount(MRI, ResType);
+  else
+    report_fatal_error("Cannot select G_BUILD_VECTOR with a non-vector result");
+  if (I.getNumExplicitOperands() - I.getNumExplicitDefs() != N)
+    report_fatal_error("G_BUILD_VECTOR and the result type are inconsistent");
+
+  // check if we may construct a constant vector
+  bool IsConst = true;
+  for (unsigned i = I.getNumExplicitDefs();
+       i < I.getNumExplicitOperands() && IsConst; ++i)
+    if (!isConstReg(MRI, I.getOperand(i).getReg()))
+      IsConst = false;
+
+  if (!IsConst && N < 2)
+    report_fatal_error(
+        "There must be at least two constituent operands in a vector");
+
+  auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
+                     TII.get(IsConst ? SPIRV::OpConstantComposite
+                                     : SPIRV::OpCompositeConstruct))
+                 .addDef(ResVReg)
+                 .addUse(GR.getSPIRVTypeID(ResType));
+  for (unsigned i = I.getNumExplicitDefs(); i < I.getNumExplicitOperands(); ++i)
+    MIB.addUse(I.getOperand(i).getReg());
+  return MIB.constrainAllUses(TII, TRI, RBI);
+}
+
 bool SPIRVInstructionSelector::selectSplatVector(Register ResVReg,
                                                  const SPIRVType *ResType,
                                                  MachineInstr &I) const {
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 6838f4bf9410f0..7c158abba3c28c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -530,15 +530,23 @@ generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
           MachineInstr *ElemMI = MRI.getVRegDef(MI.getOperand(1).getReg());
           assert(ElemMI);
 
-          if (ElemMI->getOpcode() == TargetOpcode::G_CONSTANT)
+          if (ElemMI->getOpcode() == TargetOpcode::G_CONSTANT) {
             ElemTy = ElemMI->getOperand(1).getCImm()->getType();
-          else if (ElemMI->getOpcode() == TargetOpcode::G_FCONSTANT)
+          } else if (ElemMI->getOpcode() == TargetOpcode::G_FCONSTANT) {
             ElemTy = ElemMI->getOperand(1).getFPImm()->getType();
+          } else {
+            // There may be a case when we already know Reg's type.
+            MachineInstr *NextMI = MI.getNextNode();
+            if (!NextMI || NextMI->getOpcode() != SPIRV::ASSIGN_TYPE ||
+                NextMI->getOperand(1).getReg() != Reg)
+              llvm_unreachable("Unexpected opcode");
+          }
+          if (ElemTy)
+            Ty = VectorType::get(
+                ElemTy, MI.getNumExplicitOperands() - MI.getNumExplicitDefs(),
+                false);
           else
-            llvm_unreachable("Unexpected opcode");
-          unsigned NumElts =
-              MI.getNumExplicitOperands() - MI.getNumExplicitDefs();
-          Ty = VectorType::get(ElemTy, NumElts, false);
+            NeedAssignType = false;
         }
         if (NeedAssignType)
           insertAssignInstr(Reg, Ty, nullptr, GR, MIB, MRI);
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/OpGroupBroadcast.ll b/llvm/test/CodeGen/SPIRV/transcoding/OpGroupBroadcast.ll
new file mode 100644
index 00000000000000..fd737e85a8ec50
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/transcoding/OpGroupBroadcast.ll
@@ -0,0 +1,152 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV: OpCapability Groups
+; CHECK-SPIRV-DAG: %[[#Int32Ty:]] = OpTypeInt 32 0
+; CHECK-SPIRV-DAG: %[[#Int64Ty:]] = OpTypeInt 64 0
+; CHECK-SPIRV-DAG: %[[#Float32Ty:]] = OpTypeFloat 32
+; CHECK-SPIRV-DAG: %[[#Vec2Int32Ty:]] = OpTypeVector %[[#Int32Ty]] 2
+; CHECK-SPIRV-DAG: %[[#Vec3Int32Ty:]] = OpTypeVector %[[#Int32Ty]] 3
+; CHECK-SPIRV-DAG: %[[#Vec2Int64Ty:]] = OpTypeVector %[[#Int64Ty]] 2
+; CHECK-SPIRV-DAG: %[[#C2:]] = OpConstant %[[#Int32Ty]] 2
+
+; CHECK-SPIRV: OpFunction
+; CHECK-SPIRV: %[[#Val:]] = OpFunctionParameter %[[#Int32Ty]]
+; CHECK-SPIRV: %[[#X:]] = OpFunctionParameter %[[#Int32Ty]]
+; CHECK-SPIRV: %[[#Y:]] = OpFunctionParameter %[[#Int32Ty]]
+; CHECK-SPIRV: %[[#Z:]] = OpFunctionParameter %[[#Int32Ty]]
+; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#Int32Ty]] %[[#C2]] %[[#Val]] %[[#X]]
+; CHECK-SPIRV: %[[#XY:]] = OpCompositeConstruct %[[#Vec2Int32Ty]] %[[#X]] %[[#Y]]
+; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#Int32Ty]] %[[#C2]] %[[#Val]] %[[#XY]]
+; CHECK-SPIRV: %[[#XYZ:]] = OpCompositeConstruct %[[#Vec3Int32Ty]] %[[#X]] %[[#Y]] %[[#Z]]
+; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#Int32Ty]] %[[#C2]] %[[#Val]] %[[#XYZ]]
+define spir_kernel void @test_broadcast_xyz(i32 noundef %a, i32 noundef %x, i32 noundef %y, i32 noundef %z) {
+entry:
+  %call1 = call spir_func i32 @_Z20work_group_broadcastjj(i32 noundef %a, i32 noundef %x)
+  %call2 = call spir_func i32 @_Z20work_group_broadcastjj(i32 noundef %a, i32 noundef %x, i32 noundef %y)
+  %call3 = call spir_func i32 @_Z20work_group_broadcastjj(i32 noundef %a, i32 noundef %x, i32 noundef %y, i32 noundef %z)
+  ret void
+}
+
+declare spir_func i32 @_Z20work_group_broadcastjj(i32, i32)
+declare spir_func i32 @_Z20work_group_broadcastjjj(i32, i32, i32)
+declare spir_func i32 @_Z20work_group_broadcastjjjj(i32, i32, i32, i32)
+
+; CHECK-SPIRV: OpFunction
+; CHECK-SPIRV: OpInBoundsPtrAccessChain
+; CHECK-SPIRV: %[[#LoadedVal:]] = OpLoad %[[#Float32Ty]] %[[#]]
+; CHECK-SPIRV: %[[#IdX:]] = OpCompositeExtract %[[#Int64Ty]] %[[#]] 0
+; CHECK-SPIRV: %[[#IdY:]] = OpCompositeExtract %[[#Int64Ty]] %[[#]] 1
+; CHECK-SPIRV: %[[#LocIdsVec:]] = OpCompositeConstruct %[[#Vec2Int64Ty]] %[[#IdX]] %[[#IdY]]
+; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#Float32Ty]] %[[#C2]] %[[#LoadedVal]] %[[#LocIdsVec]]
+define spir_kernel void @test_wg_broadcast_2D(ptr addrspace(1) %input, ptr addrspace(1) %output) #0 !kernel_arg_addr_space !7 !kernel_arg_access_qual !8 !kernel_arg_type !9 !kernel_arg_type_qual !10 !kernel_arg_base_type !9 !spirv.ParameterDecorations !11 {
+entry:
+  %0 = call spir_func i64 @_Z13get_global_idj(i32 0) #1
+  %1 = insertelement <3 x i64> undef, i64 %0, i32 0
+  %2 = call spir_func i64 @_Z13get_global_idj(i32 1) #1
+  %3 = insertelement <3 x i64> %1, i64 %2, i32 1
+  %4 = call spir_func i64 @_Z13get_global_idj(i32 2) #1
+  %5 = insertelement <3 x i64> %3, i64 %4, i32 2
+  %call = extractelement <3 x i64> %5, i32 0
+  %6 = call spir_func i64 @_Z13get_global_idj(i32 0) #1
+  %7 = insertelement <3 x i64> undef, i64 %6, i32 0
+  %8 = call spir_func i64 @_Z13get_global_idj(i32 1) #1
+  %9 = insertelement <3 x i64> %7, i64 %8, i32 1
+  %10 = call spir_func i64 @_Z13get_global_idj(i32 2) #1
+  %11 = insertelement <3 x i64> %9, i64 %10, i32 2
+  %call1 = extractelement <3 x i64> %11, i32 1
+  %12 = call spir_func i64 @_Z12get_group_idj(i32 0) #1
+  %13 = insertelement <3 x i64> undef, i64 %12, i32 0
+  %14 = call spir_func i64 @_Z12get_group_idj(i32 1) #1
+  %15 = insertelement <3 x i64> %13, i64 %14, i32 1
+  %16 = call spir_func i64 @_Z12get_group_idj(i32 2) #1
+  %17 = insertelement <3 x i64> %15, i64 %16, i32 2
+  %call2 = extractelement <3 x i64> %17, i32 0
+  %18 = call spir_func i64 @_Z14get_local_sizej(i32 0) #1
+  %19 = insertelement <3 x i64> undef, i64 %18, i32 0
+  %20 = call spir_func i64 @_Z14get_local_sizej(i32 1) #1
+  %21 = insertelement <3 x i64> %19, i64 %20, i32 1
+  %22 = call spir_func i64 @_Z14get_local_sizej(i32 2) #1
+  %23 = insertelement <3 x i64> %21, i64 %22, i32 2
+  %call3 = extractelement <3 x i64> %23, i32 0
+  %rem = urem i64 %call2, %call3
+  %24 = call spir_func i64 @_Z12get_group_idj(i32 0) #1
+  %25 = insertelement <3 x i64> undef, i64 %24, i32 0
+  %26 = call spir_func i64 @_Z12get_group_idj(i32 1) #1
+  %27 = insertelement <3 x i64> %25, i64 %26, i32 1
+  %28 = call spir_func i64 @_Z12get_group_idj(i32 2) #1
+  %29 = insertelement <3 x i64> %27, i64 %28, i32 2
+  %call4 = extractelement <3 x i64> %29, i32 1
+  %30 = call spir_func i64 @_Z14get_local_sizej(i32 0) #1
+  %31 = insertelement <3 x i64> undef, i64 %30, i32 0
+  %32 = call spir_func i64 @_Z14get_local_sizej(i32 1) #1
+  %33 = insertelement <3 x i64> %31, i64 %32, i32 1
+  %34 = call spir_func i64 @_Z14get_local_sizej(i32 2) #1
+  %35 = insertelement <3 x i64> %33, i64 %34, i32 2
+  %call5 = extractelement <3 x i64> %35, i32 1
+  %rem6 = urem i64 %call4, %call5
+  %36 = call spir_func i64 @_Z15get_global_sizej(i32 0) #1
+  %37 = insertelement <3 x i64> undef, i64 %36, i32 0
+  %38 = call spir_func i64 @_Z15get_global_sizej(i32 1) #1
+  %39 = insertelement <3 x i64> %37, i64 %38, i32 1
+  %40 = call spir_func i64 @_Z15get_global_sizej(i32 2) #1
+  %41 = insertelement <3 x i64> %39, i64 %40, i32 2
+  %call7 = extractelement <3 x i64> %41, i32 0
+  %mul = mul i64 %call1, %call7
+  %add = add i64 %mul, %call
+  %arrayidx = getelementptr inbounds float, ptr addrspace(1) %input, i64 %add
+  %42 = load float, ptr addrspace(1) %arrayidx, align 4
+  %.splatinsert = insertelement <2 x i64> undef, i64 %rem, i32 0
+  %.splat = shufflevector <2 x i64> %.splatinsert, <2 x i64> undef, <2 x i32> zeroinitializer
+  %43 = insertelement <2 x i64> %.splat, i64 %rem6, i32 1
+  %44 = extractelement <2 x i64> %43, i32 0
+  %45 = extractelement <2 x i64> %43, i32 1
+  %call8 = call spir_func float @_Z20work_group_broadcastfmm(float %42, i64 %44, i64 %45) #2
+  %arrayidx9 = getelementptr inbounds float, ptr addrspace(1) %output, i64 %add
+  store float %call8, ptr addrspace(1) %arrayidx9, align 4
+  ret void
+}
+
+; Function Attrs: nounwind willreturn memory(none)
+declare spir_func i64 @_Z13get_global_idj(i32) #1
+
+; Function Attrs: nounwind willreturn memory(none)
+declare spir_func i64 @_Z12get_group_idj(i32) #1
+
+; Function Attrs: nounwind willreturn memory(none)
+declare spir_func i64 @_Z14get_local_sizej(i32) #1
+
+; Function Attrs: nounwind willreturn memory(none)
+declare spir_func i64 @_Z15get_global_sizej(i32) #1
+
+; Function Attrs: convergent nounwind
+declare spir_func float @_Z20work_group_broadcastfmm(float, i64, i64) #2
+
+attributes #0 = { nounwind }
+attributes #1 = { nounwind willreturn memory(none) }
+attributes #2 = { convergent nounwind }
+
+!spirv.MemoryModel = !{!0}
+!opencl.enable.FP_CONTRACT = !{}
+!spirv.Source = !{!1}
+!opencl.spir.version = !{!2}
+!opencl.ocl.version = !{!3}
+!opencl.used.extensions = !{!4}
+!opencl.used.optional.core.features = !{!5}
+!spirv.Generator = !{!6}
+
+!0 = !{i32 2, i32 2}
+!1 = !{i32 3, i32 300000}
+!2 = !{i32 2, i32 0}
+!3 = !{i32 3, i32 0}
+!4 = !{!"cl_khr_subgroups"}
+!5 = !{}
+!6 = !{i16 6, i16 14}
+!7 = !{i32 1, i32 1}
+!8 = !{!"none", !"none"}
+!9 = !{!"float*", !"float*"}
+!10 = !{!"", !""}
+!11 = !{!5, !5}
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/group_ops.ll b/llvm/test/CodeGen/SPIRV/transcoding/group_ops.ll
index 65795d23657c24..870d9b76d5f3eb 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/group_ops.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/group_ops.ll
@@ -5,6 +5,8 @@
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 
 ; CHECK-SPIRV-DAG: %[[#int:]] = OpTypeInt 32 0
+; CHECK-SPIRV-DAG: %[[#intv2:]] = OpTypeVector %[[#int]] 2
+; CHECK-SPIRV-DAG: %[[#intv3:]] = OpTypeVector %[[#int]] 3
 ; CHECK-SPIRV-DAG: %[[#float:]] = OpTypeFloat 32
 ; CHECK-SPIRV-DAG: %[[#ScopeCrossWorkgroup:]] = OpConstant %[[#int]] 0
 ; CHECK-SPIRV-DAG: %[[#ScopeWorkgroup:]] = OpConstant %[[#int]] 2
@@ -252,6 +254,10 @@ declare spir_func i32 @_Z21work_group_reduce_minj(i32 noundef) local_unnamed_add
 
 ; CHECK-SPIRV: OpFunction
 ; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#int]] %[[#ScopeWorkgroup]] %[[#BroadcastValue:]] %[[#BroadcastLocalId:]]
+; CHECK-SPIRV: %[[#BroadcastVec2:]] = OpCompositeConstruct %[[#intv2]] %[[#BroadcastLocalId]] %[[#BroadcastLocalId]]
+; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#int]] %[[#ScopeWorkgroup]] %[[#BroadcastValue]] %[[#BroadcastVec2]]
+; CHECK-SPIRV: %[[#BroadcastVec3:]] = OpCompositeConstruct %[[#intv3]] %[[#BroadcastLocalId]] %[[#BroadcastLocalId]] %[[#BroadcastLocalId]]
+; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#int]] %[[#ScopeWorkgroup]] %[[#BroadcastValue]] %[[#BroadcastVec3]]
 ; CHECK-SPIRV: %[[#]] = OpGroupBroadcast %[[#int]] %[[#ScopeCrossWorkgroup]] %[[#BroadcastValue]] %[[#BroadcastLocalId]]
 ; CHECK-SPIRV: OpFunctionEnd
 
@@ -263,12 +269,16 @@ define dso_local spir_kernel void @testWorkGroupBroadcast(i32 noundef %a, i32 ad
 entry:
   %0 = load i32, i32 addrspace(1)* %id, align 4
   %call = call spir_func i32 @_Z20work_group_broadcastjj(i32 noundef %a, i32 noundef %0)
+  %call_v2 = call spir_func i32 @_Z20work_group_broadcastjj(i32 noundef %a, i32 noundef %0, i32 noundef %0)
+  %call_v3 = call spir_func i32 @_Z20work_group_broadcastjj(i32 noundef %a, i32 noundef %0, i32 noundef %0, i32 noundef %0)
   store i32 %call, i32 addrspace(1)* %res, align 4
   %call1 = call spir_func i32 @__spirv_GroupBroadcast(i32 0, i32 noundef %a, i32 noundef %0)
   ret void
 }
 
 declare spir_func i32 @_Z20work_group_broadcastjj(i32 noundef, i32 noundef) local_unnamed_addr
+declare spir_func i32 @_Z20work_group_broadcastjjj(i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr
+declare spir_func i32 @_Z20work_group_broadcastjjjj(i32 noundef, i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr
 declare spir_func i32 @__spirv_GroupBroadcast(i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr
 
 ; CHECK-SPIRV: OpFunction

@michalpaszkowski
Copy link
Member

Thank you! Confirmed this pull request solves the issue in #97310!

./test_workgroups work_group_broadcast_2D
 Initializing random seed to 0.
Requesting Default device based on command line for platform index 0 and device index 0
Compute Device Name = Intel(R) Graphics [0xa780], Compute Device Vendor = Intel(R) Corporation, Compute Device Version = OpenCL 3.0 NEO , CL C Version = OpenCL C 1.2 
Device latest conformance version passed: v2022-04-22-00
Supports single precision denormals: YES
sizeof( void*) = 8  (host)
sizeof( void*) = 8  (device)
work_group_broadcast_2D...
work_group_broadcast_2D test passed
work_group_broadcast_2D passed
PASSED sub-test.
PASSED test.

@VyacheslavLevytskyy VyacheslavLevytskyy merged commit 2fc7a72 into llvm:main Aug 14, 2024
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[SPIR-V] Missing/incorrect builtin implementation OpGroupBroadcast
3 participants