[SPIR-V] Add saturation and float rounding mode decorations, a subset of arithmetic constrained floating-point intrinsics, and SPV_INTEL_float_controls2 extension #119862

Merged

Conversation

VyacheslavLevytskyy
Contributor

@VyacheslavLevytskyy VyacheslavLevytskyy commented Dec 13, 2024

This PR adds the following features:

  • saturation and float rounding mode decorations,
  • arithmetic constrained floating-point intrinsics (strict_fadd, strict_fsub, strict_fmul, strict_fdiv, strict_frem, strict_fma and strict_fldexp),
  • the SPV_INTEL_float_controls2 extension,
  • simpler pre- and post-legalizer steps and improved instruction selection, building on recent improvements to the emit-intrinsics step.
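
For readers who want the rounding-mode mapping at a glance, here is a minimal standalone sketch. It does not use LLVM and is not code from the patch. The helper name `toSpirvRoundingMode` and the local `FPRoundingMode` enum are illustrative assumptions. The input strings are the rounding-mode metadata values that LLVM constrained floating-point intrinsics carry, and the enum values follow the SPIR-V specification. The mapping mirrors the `switch` in `useRoundingMode` from the diff: dynamic and ties-to-away modes produce no decoration.

```cpp
#include <cstdint>
#include <optional>
#include <string_view>

// SPIR-V FPRoundingMode operand values, as defined by the SPIR-V specification.
enum class FPRoundingMode : std::uint32_t { RTE = 0, RTZ = 1, RTP = 2, RTN = 3 };

// Hypothetical helper (not part of the patch): maps the rounding-mode metadata
// string of an LLVM constrained FP intrinsic to a SPIR-V FPRoundingMode.
// Returns std::nullopt when no decoration should be emitted.
std::optional<FPRoundingMode> toSpirvRoundingMode(std::string_view MD) {
  if (MD == "round.tonearest")
    return FPRoundingMode::RTE; // nearest, ties to even
  if (MD == "round.towardzero")
    return FPRoundingMode::RTZ;
  if (MD == "round.upward")
    return FPRoundingMode::RTP; // toward positive infinity
  if (MD == "round.downward")
    return FPRoundingMode::RTN; // toward negative infinity
  // "round.dynamic", "round.tonearestaway", and unknown strings: no decoration.
  return std::nullopt;
}
```

For example, `toSpirvRoundingMode("round.upward")` yields `FPRoundingMode::RTP`, while `toSpirvRoundingMode("round.dynamic")` yields an empty optional, so the instruction would stay undecorated.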

github-actions bot commented Dec 13, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@VyacheslavLevytskyy VyacheslavLevytskyy changed the title from "[SPIR-V] Add saturation and float rounding mode decorations, several arithmetic constrained floating-point intrinsics, and SPV_INTEL_float_controls2 extension" to "[SPIR-V] Add saturation and float rounding mode decorations, a subset of arithmetic constrained floating-point intrinsics, and SPV_INTEL_float_controls2 extension" Dec 13, 2024
@VyacheslavLevytskyy VyacheslavLevytskyy marked this pull request as ready for review December 14, 2024 00:21
@llvmbot
Member

llvmbot commented Dec 14, 2024

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

Patch is 43.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/119862.diff

18 Files Affected:

  • (modified) llvm/docs/SPIRVUsage.rst (+2)
  • (modified) llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp (+4-5)
  • (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (+18-5)
  • (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.h (+2-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp (+2)
  • (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+97-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+7-7)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstrInfo.td (+6)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+37-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp (+12-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (+21-4)
  • (modified) llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp (+9-22)
  • (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+3-33)
  • (modified) llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td (+11)
  • (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_float_controls2/exec_mode_float_control_empty.ll (+18)
  • (added) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_float_controls2/exec_mode_float_control_intel.ll (+74)
  • (modified) llvm/test/CodeGen/SPIRV/instructions/integer-casts.ll (+28)
  • (added) llvm/test/CodeGen/SPIRV/llvm-intrinsics/constrained-arithmetic.ll (+44)
diff --git a/llvm/docs/SPIRVUsage.rst b/llvm/docs/SPIRVUsage.rst
index 8f7ac71f8026b3..b7b3d21545168c 100644
--- a/llvm/docs/SPIRVUsage.rst
+++ b/llvm/docs/SPIRVUsage.rst
@@ -159,6 +159,8 @@ list of supported SPIR-V extensions, sorted alphabetically by their extension na
      - Adds instructions to convert between single-precision 32-bit floating-point values and 16-bit bfloat16 values.
    * - ``SPV_INTEL_cache_controls``
      - Allows cache control information to be applied to memory access instructions.
+   * - ``SPV_INTEL_float_controls2``
+     - Adds execution modes and decorations to control floating-point computations.
    * - ``SPV_INTEL_function_pointers``
      - Allows translation of function pointers.
    * - ``SPV_INTEL_inline_assembly``
diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp
index 42567f695395ef..68cc6a3a7aac1b 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp
@@ -65,11 +65,10 @@ static bool hasType(const MCInst &MI, const MCInstrInfo &MII) {
   // If we define an output, and have at least one other argument.
   if (MCDesc.getNumDefs() == 1 && MCDesc.getNumOperands() >= 2) {
     // Check if we define an ID, and take a type as operand 1.
-    auto &DefOpInfo = MCDesc.operands()[0];
-    auto &FirstArgOpInfo = MCDesc.operands()[1];
-    return DefOpInfo.RegClass >= 0 && FirstArgOpInfo.RegClass >= 0 &&
-           DefOpInfo.RegClass != SPIRV::TYPERegClassID &&
-           FirstArgOpInfo.RegClass == SPIRV::TYPERegClassID;
+    return MCDesc.operands()[0].RegClass >= 0 &&
+           MCDesc.operands()[1].RegClass >= 0 &&
+           MCDesc.operands()[0].RegClass != SPIRV::TYPERegClassID &&
+           MCDesc.operands()[1].RegClass == SPIRV::TYPERegClassID;
   }
   return false;
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index f4bfda4932b167..4bfa51e2cccdd8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -173,7 +173,8 @@ using namespace InstructionSet;
 
 namespace SPIRV {
 /// Parses the name part of the demangled builtin call.
-std::string lookupBuiltinNameHelper(StringRef DemangledCall) {
+std::string lookupBuiltinNameHelper(StringRef DemangledCall,
+                                    std::string *Postfix) {
   const static std::string PassPrefix = "(anonymous namespace)::";
   std::string BuiltinName;
   // Itanium Demangler result may have "(anonymous namespace)::" prefix
@@ -231,10 +232,13 @@ std::string lookupBuiltinNameHelper(StringRef DemangledCall) {
       "ReadClockKHR|SubgroupBlockReadINTEL|SubgroupImageBlockReadINTEL|"
       "SubgroupImageMediaBlockReadINTEL|SubgroupImageMediaBlockWriteINTEL|"
       "Convert|"
-      "UConvert|SConvert|FConvert|SatConvert).*)_R.*");
+      "UConvert|SConvert|FConvert|SatConvert).*)_R(.*)");
   std::smatch Match;
-  if (std::regex_match(BuiltinName, Match, SpvWithR) && Match.size() > 2)
+  if (std::regex_match(BuiltinName, Match, SpvWithR) && Match.size() > 3) {
     BuiltinName = Match[1].str();
+    if (Postfix)
+      *Postfix = Match[3].str();
+  }
 
   return BuiltinName;
 }
@@ -583,6 +587,15 @@ static Register buildScopeReg(Register CLScopeRegister,
   return buildConstantIntReg32(Scope, MIRBuilder, GR);
 }
 
+static void setRegClassIfNull(Register Reg, MachineRegisterInfo *MRI,
+                              SPIRVGlobalRegistry *GR) {
+  if (MRI->getRegClassOrNull(Reg))
+    return;
+  SPIRVType *SpvType = GR->getSPIRVTypeForVReg(Reg);
+  MRI->setRegClass(Reg,
+                   SpvType ? GR->getRegClass(SpvType) : &SPIRV::iIDRegClass);
+}
+
 static Register buildMemSemanticsReg(Register SemanticsRegister,
                                      Register PtrRegister, unsigned &Semantics,
                                      MachineIRBuilder &MIRBuilder,
@@ -1160,7 +1173,7 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
         MIRBuilder.buildInstr(TargetOpcode::G_BUILD_VECTOR).addDef(VecReg);
     for (unsigned i = 1; i < Call->Arguments.size(); i++) {
       MIB.addUse(Call->Arguments[i]);
-      MRI->setRegClass(Call->Arguments[i], &SPIRV::iIDRegClass);
+      setRegClassIfNull(Call->Arguments[i], MRI, GR);
     }
     insertAssignInstr(VecReg, nullptr, VecType, GR, MIRBuilder,
                       MIRBuilder.getMF().getRegInfo());
@@ -1176,7 +1189,7 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
     MIB.addImm(GroupBuiltin->GroupOperation);
   if (Call->Arguments.size() > 0) {
     MIB.addUse(Arg0.isValid() ? Arg0 : Call->Arguments[0]);
-    MRI->setRegClass(Call->Arguments[0], &SPIRV::iIDRegClass);
+    setRegClassIfNull(Call->Arguments[0], MRI, GR);
     if (VecReg.isValid())
       MIB.addUse(VecReg);
     else
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
index 42b452db8b9fb4..0182d9652d18c9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
@@ -20,7 +20,8 @@
 namespace llvm {
 namespace SPIRV {
 /// Parses the name part of the demangled builtin call.
-std::string lookupBuiltinNameHelper(StringRef DemangledCall);
+std::string lookupBuiltinNameHelper(StringRef DemangledCall,
+                                    std::string *Postfix = nullptr);
 /// Lowers a builtin function call using the provided \p DemangledCall skeleton
 /// and external instruction \p Set.
 ///
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index fb05c1fdbd1e3b..45b39c51164795 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -36,6 +36,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
          SPIRV::Extension::Extension::SPV_INTEL_arbitrary_precision_integers},
         {"SPV_INTEL_cache_controls",
          SPIRV::Extension::Extension::SPV_INTEL_cache_controls},
+        {"SPV_INTEL_float_controls2",
+         SPIRV::Extension::Extension::SPV_INTEL_float_controls2},
         {"SPV_INTEL_global_variable_fpga_decorations",
          SPIRV::Extension::Extension::
              SPV_INTEL_global_variable_fpga_decorations},
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 2b623136e602e5..433956f44917fb 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -216,6 +216,8 @@ class SPIRVEmitIntrinsics
   bool processFunctionPointers(Module &M);
   void parseFunDeclarations(Module &M);
 
+  void useRoundingMode(ConstrainedFPIntrinsic *FPI, IRBuilder<> &B);
+
 public:
   static char ID;
   SPIRVEmitIntrinsics() : ModulePass(ID) {
@@ -1291,6 +1293,37 @@ void SPIRVEmitIntrinsics::preprocessCompositeConstants(IRBuilder<> &B) {
   }
 }
 
+static void createDecorationIntrinsic(Instruction *I, MDNode *Node,
+                                      IRBuilder<> &B) {
+  LLVMContext &Ctx = I->getContext();
+  setInsertPointAfterDef(B, I);
+  B.CreateIntrinsic(Intrinsic::spv_assign_decoration, {I->getType()},
+                    {I, MetadataAsValue::get(Ctx, MDNode::get(Ctx, {Node}))});
+}
+
+static void createRoundingModeDecoration(Instruction *I,
+                                         unsigned RoundingModeDeco,
+                                         IRBuilder<> &B) {
+  LLVMContext &Ctx = I->getContext();
+  Type *Int32Ty = Type::getInt32Ty(Ctx);
+  MDNode *RoundingModeNode = MDNode::get(
+      Ctx,
+      {ConstantAsMetadata::get(
+           ConstantInt::get(Int32Ty, SPIRV::Decoration::FPRoundingMode)),
+       ConstantAsMetadata::get(ConstantInt::get(Int32Ty, RoundingModeDeco))});
+  createDecorationIntrinsic(I, RoundingModeNode, B);
+}
+
+static void createSaturatedConversionDecoration(Instruction *I,
+                                                IRBuilder<> &B) {
+  LLVMContext &Ctx = I->getContext();
+  Type *Int32Ty = Type::getInt32Ty(Ctx);
+  MDNode *SaturatedConversionNode =
+      MDNode::get(Ctx, {ConstantAsMetadata::get(ConstantInt::get(
+                           Int32Ty, SPIRV::Decoration::SaturatedConversion))});
+  createDecorationIntrinsic(I, SaturatedConversionNode, B);
+}
+
 Instruction *SPIRVEmitIntrinsics::visitCallInst(CallInst &Call) {
   if (!Call.isInlineAsm())
     return &Call;
@@ -1312,6 +1345,40 @@ Instruction *SPIRVEmitIntrinsics::visitCallInst(CallInst &Call) {
   return &Call;
 }
 
+// Use a tip about rounding mode to create a decoration.
+void SPIRVEmitIntrinsics::useRoundingMode(ConstrainedFPIntrinsic *FPI,
+                                          IRBuilder<> &B) {
+  std::optional<RoundingMode> RM = FPI->getRoundingMode();
+  if (!RM.has_value())
+    return;
+  unsigned RoundingModeDeco = std::numeric_limits<unsigned>::max();
+  switch (RM.value()) {
+  default:
+    // ignore unknown rounding modes
+    break;
+  case RoundingMode::NearestTiesToEven:
+    RoundingModeDeco = SPIRV::FPRoundingMode::FPRoundingMode::RTE;
+    break;
+  case RoundingMode::TowardNegative:
+    RoundingModeDeco = SPIRV::FPRoundingMode::FPRoundingMode::RTN;
+    break;
+  case RoundingMode::TowardPositive:
+    RoundingModeDeco = SPIRV::FPRoundingMode::FPRoundingMode::RTP;
+    break;
+  case RoundingMode::TowardZero:
+    RoundingModeDeco = SPIRV::FPRoundingMode::FPRoundingMode::RTZ;
+    break;
+  case RoundingMode::Dynamic:
+  case RoundingMode::NearestTiesToAway:
+    // TODO: check if supported
+    break;
+  }
+  if (RoundingModeDeco == std::numeric_limits<unsigned>::max())
+    return;
+  // Convert the tip about rounding mode into a decoration record.
+  createRoundingModeDecoration(FPI, RoundingModeDeco, B);
+}
+
 Instruction *SPIRVEmitIntrinsics::visitSwitchInst(SwitchInst &I) {
   BasicBlock *ParentBB = I.getParent();
   IRBuilder<> B(ParentBB);
@@ -1809,6 +1876,18 @@ bool SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I,
   return true;
 }
 
+static unsigned roundingModeMDToDecorationConst(StringRef S) {
+  if (S == "rte")
+    return SPIRV::FPRoundingMode::FPRoundingMode::RTE;
+  if (S == "rtz")
+    return SPIRV::FPRoundingMode::FPRoundingMode::RTZ;
+  if (S == "rtp")
+    return SPIRV::FPRoundingMode::FPRoundingMode::RTP;
+  if (S == "rtn")
+    return SPIRV::FPRoundingMode::FPRoundingMode::RTN;
+  return std::numeric_limits<unsigned>::max();
+}
+
 void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
                                                 IRBuilder<> &B) {
   // TODO: extend the list of functions with known result types
@@ -1826,8 +1905,9 @@ void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
       Function *CalledF = CI->getCalledFunction();
       std::string DemangledName =
           getOclOrSpirvBuiltinDemangledName(CalledF->getName());
+      std::string Postfix;
       if (DemangledName.length() > 0)
-        DemangledName = SPIRV::lookupBuiltinNameHelper(DemangledName);
+        DemangledName = SPIRV::lookupBuiltinNameHelper(DemangledName, &Postfix);
       auto ResIt = ResTypeWellKnown.find(DemangledName);
       if (ResIt != ResTypeWellKnown.end()) {
         IsKnown = true;
@@ -1839,6 +1919,19 @@ void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
           break;
         }
       }
+      // check if a floating rounding mode info is present
+      StringRef S = Postfix;
+      SmallVector<StringRef, 8> Parts;
+      S.split(Parts, "_", -1, false);
+      if (Parts.size() > 1) {
+        // Convert the info about rounding mode into a decoration record.
+        unsigned RoundingModeDeco = roundingModeMDToDecorationConst(Parts[1]);
+        if (RoundingModeDeco != std::numeric_limits<unsigned>::max())
+          createRoundingModeDecoration(CI, RoundingModeDeco, B);
+        // Check if the SaturatedConversion info is present.
+        if (Parts[1] == "sat")
+          createSaturatedConversionDecoration(CI, B);
+      }
     }
   }
 
@@ -2264,6 +2357,9 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
     // already, and force it to be i8 if not
     if (Postpone && !GR->findAssignPtrTypeInstr(I))
       insertAssignPtrTypeIntrs(I, B, true);
+
+    if (auto *FPI = dyn_cast<ConstrainedFPIntrinsic>(I))
+      useRoundingMode(FPI, B);
   }
 
   // Pass backward: use instructions results to specify/update/cast operands
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 5f72a41ddb8647..3e913646d57c80 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -126,14 +126,14 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
   Width = adjustOpTypeIntWidth(Width);
   const SPIRVSubtarget &ST =
       cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
-  if (ST.canUseExtension(
-          SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
-    MIRBuilder.buildInstr(SPIRV::OpExtension)
-        .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
-    MIRBuilder.buildInstr(SPIRV::OpCapability)
-        .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
-  }
   return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+    if (ST.canUseExtension(
+            SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
+      MIRBuilder.buildInstr(SPIRV::OpExtension)
+          .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
+      MIRBuilder.buildInstr(SPIRV::OpCapability)
+          .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
+    }
     return MIRBuilder.buildInstr(SPIRV::OpTypeInt)
         .addDef(createTypeVReg(MIRBuilder))
         .addImm(Width)
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index d95803fea56a58..1bc35c6e57a4f6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -491,16 +491,20 @@ def OpFNegate: UnOpTyped<"OpFNegate", 127, fID, fneg>;
 def OpFNegateV: UnOpTyped<"OpFNegate", 127, vfID, fneg>;
 defm OpIAdd: BinOpTypedGen<"OpIAdd", 128, add, 0, 1>;
 defm OpFAdd: BinOpTypedGen<"OpFAdd", 129, fadd, 1, 1>;
+defm OpStrictFAdd: BinOpTypedGen<"OpFAdd", 129, strict_fadd, 1, 1>;
 
 defm OpISub: BinOpTypedGen<"OpISub", 130, sub, 0, 1>;
 defm OpFSub: BinOpTypedGen<"OpFSub", 131, fsub, 1, 1>;
+defm OpStrictFSub: BinOpTypedGen<"OpFSub", 131, strict_fsub, 1, 1>;
 
 defm OpIMul: BinOpTypedGen<"OpIMul", 132, mul, 0, 1>;
 defm OpFMul: BinOpTypedGen<"OpFMul", 133, fmul, 1, 1>;
+defm OpStrictFMul: BinOpTypedGen<"OpFMul", 133, strict_fmul, 1, 1>;
 
 defm OpUDiv: BinOpTypedGen<"OpUDiv", 134, udiv, 0, 1>;
 defm OpSDiv: BinOpTypedGen<"OpSDiv", 135, sdiv, 0, 1>;
 defm OpFDiv: BinOpTypedGen<"OpFDiv", 136, fdiv, 1, 1>;
+defm OpStrictFDiv: BinOpTypedGen<"OpFDiv", 136, strict_fdiv, 1, 1>;
 
 defm OpUMod: BinOpTypedGen<"OpUMod", 137, urem, 0, 1>;
 defm OpSRem: BinOpTypedGen<"OpSRem", 138, srem, 0, 1>;
@@ -508,6 +512,8 @@ defm OpSRem: BinOpTypedGen<"OpSRem", 138, srem, 0, 1>;
 def OpSMod: BinOp<"OpSMod", 139>;
 
 defm OpFRem: BinOpTypedGen<"OpFRem", 140, frem, 1, 1>;
+defm OpStrictFRem: BinOpTypedGen<"OpFRem", 140, strict_frem, 1, 1>;
+
 def OpFMod: BinOp<"OpFMod", 141>;
 
 def OpVectorTimesScalar: BinOp<"OpVectorTimesScalar", 142>;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index b64030508cfc11..856caf2074fba4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -61,6 +61,7 @@ class SPIRVInstructionSelector : public InstructionSelector {
   /// We need to keep track of the number we give to anonymous global values to
   /// generate the same name every time when this is needed.
   mutable DenseMap<const GlobalValue *, unsigned> UnnamedGlobalIDs;
+  SmallPtrSet<MachineInstr *, 8> DeadMIs;
 
 public:
   SPIRVInstructionSelector(const SPIRVTargetMachine &TM,
@@ -382,6 +383,24 @@ static bool isImm(const MachineOperand &MO, MachineRegisterInfo *MRI);
 // Defined in SPIRVLegalizerInfo.cpp.
 extern bool isTypeFoldingSupported(unsigned Opcode);
 
+bool isDead(const MachineInstr &MI, const MachineRegisterInfo &MRI) {
+  for (const auto &MO : MI.all_defs()) {
+    Register Reg = MO.getReg();
+    if (Reg.isPhysical() || !MRI.use_nodbg_empty(Reg))
+      return false;
+  }
+  if (MI.getOpcode() == TargetOpcode::LOCAL_ESCAPE || MI.isFakeUse() ||
+      MI.isLifetimeMarker())
+    return false;
+  if (MI.isPHI())
+    return true;
+  if (MI.mayStore() || MI.isCall() ||
+      (MI.mayLoad() && MI.hasOrderedMemoryRef()) || MI.isPosition() ||
+      MI.isDebugInstr() || MI.isTerminator() || MI.isJumpTableDebugInfo())
+    return false;
+  return true;
+}
+
 bool SPIRVInstructionSelector::select(MachineInstr &I) {
   resetVRegsType(*I.getParent()->getParent());
 
@@ -404,8 +423,11 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
           }
         });
         assert(Res || Def->getOpcode() == TargetOpcode::G_CONSTANT);
-        if (Res)
+        if (Res) {
+          if (!isTriviallyDead(*Def, *MRI) && isDead(*Def, *MRI))
+            DeadMIs.insert(Def);
           return Res;
+        }
       }
       MRI->setRegClass(SrcReg, MRI->getRegClass(DstReg));
       MRI->replaceRegWith(SrcReg, DstReg);
@@ -418,6 +440,15 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
     return constrainSelectedInstRegOperands(I, TII, TRI, RBI);
   }
 
+  if (DeadMIs.contains(&I)) {
+    // if the instruction has been already made dead by folding it away
+    // erase it
+    LLVM_DEBUG(dbgs() << "Instruction is folded and dead.\n");
+    salvageDebugInfo(*MRI, I);
+    I.eraseFromParent();
+    return true;
+  }
+
   if (I.getNumOperands() != I.getNumExplicitOperands()) {
     LLVM_DEBUG(errs() << "Generic instr has unexpected implicit operands\n");
     return false;
@@ -557,9 +588,13 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
   case TargetOpcode::G_UCMP:
     return selectSUCmp(ResVReg, ResType, I, false);
 
+  case TargetOpcode::G_STRICT_FMA:
   case TargetOpcode::G_FMA:
     return selectExtInst(ResVReg, ResType, I, CL::fma, GL::Fma);
 
+  case TargetOpcode::G_STRICT_FLDEXP:
+    return selectExtInst(ResVReg, ResType, I, CL::ldexp);
+
   case TargetOpcode::G_FPOW:
     return selectExtInst(ResVReg, ResType, I, CL::pow, GL::Pow);
   case TargetOpcode::G_FPOWI:
@@ -618,6 +653,7 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
   case TargetOpcode::G_FTANH:
     return selectExtInst(ResVReg, ResType, I, CL::tanh, GL::Tanh);
 
+  case TargetOpcode::G_STRICT_FSQRT:
   case TargetOpcode::G_FSQRT:
     return selectExtInst(ResVReg, ResType, I, CL::sqrt, GL::Sqrt);
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 7230e0e6b9fca1..b22027cd2cb931 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -24,19 +24,25 @@ using namespace llvm;
 using namespace llvm::LegalizeActions;
 using namespace llvm::LegalityPredicates;
 
+// clang-format off
 static const std::set<unsigned> TypeFoldingSupportingOpcs = {
     TargetOpcode::G_ADD,
     TargetOpcode::G_FADD,
+    TargetOpcode::G_STRICT_FADD,
     TargetOpcode::G_SUB,
     TargetOpcode::G_FSUB,
+    TargetOpcode::G_STRICT_FSUB,
     TargetOpcode::G_MUL,
     TargetOpcode::G_FMUL,
+    TargetOpcode::G_STRICT_FMUL,
     TargetOpcode::G_SDIV,
     TargetOpcode::G_UDIV,
     TargetOpcode::G_FDIV,
+    TargetOpcode::G_STRICT_FDIV,
     TargetOpcode::G_SREM,
     TargetOpcode::G_UREM,
     TargetOpcode::G_FREM,
+    TargetOpcode::G_STRICT_FREM,
     TargetOpcode::G_FNEG,
     TargetOpcode::G_CONSTANT,
     TargetOpcode::G_FCONSTANT,
@@ -49,6 +55,7 @@ static const std::set<unsigned> TypeFoldingSupportingOpcs = {
     TargetOpcode::G_SELECT,
     TargetOpcode::G_EXTRACT_VECTOR_ELT,
 };
+// clang-format on
 
 bool isTypeFoldingSupported(unsigned Opcode) {
   return TypeFoldingSupportingOpcs.count(Opcode) > 0;
@@ -219,7 +226,11 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
       .legalFor(allIntScalarsAndVectors)
       .legalIf(extendedScalarsAndVectors);
 
-  getActionDefinitionsBuilder(G_FMA).legalFor(allFloatScalarsAndVectors);
...
[truncated]
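
The `_R` postfix handling in `lookupBuiltinNameHelper` and `insertAssignTypeIntrs` above can be sketched in isolation as follows. This is a simplified illustration under stated assumptions, not the patch's code. The regex is a reduced stand-in for the much longer `SpvWithR` pattern, so the postfix is capture group 2 here rather than group 3. `ConversionHints`, `splitNonEmpty`, and `parseConversionBuiltin` are hypothetical names, and the builtin names in the usage note are illustrative inputs. As in the diff, only the second `_`-separated part of the postfix is inspected.

```cpp
#include <optional>
#include <regex>
#include <string>
#include <vector>

// Hypothetical result type: what the postfix of a conversion builtin tells us.
struct ConversionHints {
  std::string Name;                 // builtin name with the "_R..." part removed
  std::optional<unsigned> Rounding; // SPIR-V FPRoundingMode value, if present
  bool Saturated = false;           // true if the postfix carries "sat"
};

// Splits S on Sep and drops empty pieces (like StringRef::split, KeepEmpty=false).
std::vector<std::string> splitNonEmpty(const std::string &S, char Sep) {
  std::vector<std::string> Parts;
  std::string Cur;
  for (char C : S) {
    if (C == Sep) {
      if (!Cur.empty())
        Parts.push_back(Cur);
      Cur.clear();
    } else {
      Cur.push_back(C);
    }
  }
  if (!Cur.empty())
    Parts.push_back(Cur);
  return Parts;
}

ConversionHints parseConversionBuiltin(const std::string &Builtin) {
  ConversionHints H{Builtin, std::nullopt, false};
  // Simplified stand-in for SpvWithR: name, then "_R", then the postfix.
  static const std::regex WithR("(.*Convert.*)_R(.*)");
  std::smatch M;
  if (!std::regex_match(Builtin, M, WithR))
    return H;
  H.Name = M[1].str();
  // Postfix looks like "<type>_<tag>"; only the second part is inspected.
  std::vector<std::string> Parts = splitNonEmpty(M[2].str(), '_');
  if (Parts.size() > 1) {
    const std::string &Tag = Parts[1];
    if (Tag == "rte")
      H.Rounding = 0; // FPRoundingMode RTE
    else if (Tag == "rtz")
      H.Rounding = 1; // FPRoundingMode RTZ
    else if (Tag == "rtp")
      H.Rounding = 2; // FPRoundingMode RTP
    else if (Tag == "rtn")
      H.Rounding = 3; // FPRoundingMode RTN
    else if (Tag == "sat")
      H.Saturated = true;
  }
  return H;
}
```

With these assumptions, `parseConversionBuiltin("__spirv_FConvert_Rhalf_rtz")` gives the name `__spirv_FConvert` with rounding value 1 (RTZ), and `parseConversionBuiltin("__spirv_ConvertFToU_Ruchar_sat")` gives `__spirv_ConvertFToU` with `Saturated` set. A name without an `_R` part is returned unchanged with no hints.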

@VyacheslavLevytskyy VyacheslavLevytskyy merged commit 978de2d into llvm:main Dec 16, 2024
13 checks passed
3 participants