Skip to content

[mlir][nvvm]Add support for grid_constant attribute on LLVM function arguments #78228

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 12, 2024

Conversation

rishisurendran
Copy link
Contributor

@rishisurendran rishisurendran commented Jan 16, 2024

Add support for attribute nvvm.grid_constant on LLVM function arguments. The attribute can be attached only to arguments of type llvm.ptr that have llvm.byval attribute.
Generate LLVM metadata for functions with nvvm.grid_constant arguments. The metadata node is a list of integers, where each integer n denotes that the nth parameter has the
grid_constant annotation (numbering from 1). The generated metadata node will be handled by NVVM compiler. See https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#supported-properties for documentation on grid_constant property.

This patch also adds convertParameterAttr to LLVMTranslationDialectInterface for supporting the translation of derived dialect attributes on function parameters 

Copy link

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be
notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write
permissions for the repository. In which case you can instead tag reviewers by
name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review
by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate
is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot
Copy link
Member

llvmbot commented Jan 16, 2024

@llvm/pr-subscribers-llvm-ir

@llvm/pr-subscribers-mlir

Author: Rishi Surendran (rishisurendran)

Changes

Add support for attribute nvvm.grid_constant on LLVM function arguments. The attribute can be attached only to arguments of type llvm.ptr that have llvm.byval attribute.
Generate LLVM metadata for functions with nvvm.grid_constant arguments. The metadata node is a list of integers, where each integer n denotes that the nth parameter has the
grid_constant annotation (numbering from 1).

This patch also adds convertParameterAttr to LLVMTranslationDialectInterface for supporting the translation of derived dialect attributes on function parameters 


Full diff: https://github.com/llvm/llvm-project/pull/78228.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+13)
  • (modified) mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h (+26)
  • (modified) mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h (+2-2)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+28)
  • (modified) mlir/lib/Target/LLVMIR/AttrKindDetail.h (+13)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp (+57)
  • (modified) mlir/lib/Target/LLVMIR/ModuleTranslation.cpp (+32-24)
  • (modified) mlir/test/Dialect/LLVMIR/nvvm.mlir (+26)
  • (modified) mlir/test/Target/LLVMIR/nvvmir.mlir (+17)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 7140e614412f986..1fc5ee2c32bd492 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -59,6 +59,19 @@ def NVVM_Dialect : Dialect {
     /// Get the name of the attribute used to annotate max number of
     /// registers that can be allocated per thread.
     static StringRef getMaxnregAttrName() { return "nvvm.maxnreg"; }
+
+    /// Get the name of the attribute used to annotate kernel arguments that
+    /// are grid constants.
+    static StringRef getGridConstantAttrName() { return "nvvm.grid_constant"; }
+
+    /// Verify an attribute from this dialect on the argument at 'argIndex' for
+    /// the region at 'regionIndex' on the given operation. Returns failure if
+    /// the verification failed, success otherwise. This hook may optionally be
+    /// invoked from any operation containing a region.
+    LogicalResult verifyRegionArgAttribute(Operation *,
+                                           unsigned regionIndex,
+                                           unsigned argIndex,
+                                           NamedAttribute) override;
   }];
 
   let useDefaultAttributePrinterParser = 1;
diff --git a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
index 19991a6f89d80fa..55358ebc6e86efc 100644
--- a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
+++ b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_TARGET_LLVMIR_LLVMTRANSLATIONINTERFACE_H
 #define MLIR_TARGET_LLVMIR_LLVMTRANSLATIONINTERFACE_H
 
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/DialectInterface.h"
 #include "mlir/Support/LogicalResult.h"
@@ -25,6 +26,7 @@ class IRBuilderBase;
 namespace mlir {
 namespace LLVM {
 class ModuleTranslation;
+class LLVMFuncOp;
 } // namespace LLVM
 
 /// Base class for dialect interfaces providing translation to LLVM IR.
@@ -58,6 +60,16 @@ class LLVMTranslationDialectInterface
                  LLVM::ModuleTranslation &moduleTranslation) const {
     return success();
   }
+
+  /// Hook for derived dialect interface to translate or act on a derived
+  /// dialect attribute that appears on a function parameter. This gets called
+  /// after the function operation has been translated.
+  virtual LogicalResult
+  convertParameterAttr(LLVM::LLVMFuncOp function, int argIdx,
+                       NamedAttribute attr,
+                       LLVM::ModuleTranslation &moduleTranslation) const {
+    return success();
+  }
 };
 
 /// Interface collection for translation to LLVM IR, dispatches to a concrete
@@ -90,6 +102,20 @@ class LLVMTranslationInterface
     }
     return success();
   }
+
+  /// Acts on the given function operation using the interface implemented by
+  /// the dialect of one of the function parameter attributes.
+  virtual LogicalResult
+  convertParameterAttr(LLVM::LLVMFuncOp function, int argIdx,
+                       NamedAttribute attribute,
+                       LLVM::ModuleTranslation &moduleTranslation) const {
+    if (const LLVMTranslationDialectInterface *iface =
+            getInterfaceFor(attribute.getNameDialect())) {
+      return iface->convertParameterAttr(function, argIdx, attribute,
+                                         moduleTranslation);
+    }
+    return success();
+  }
 };
 
 } // namespace mlir
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index d6b03aca28d24d5..f0012bf875511ee 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -326,8 +326,8 @@ class ModuleTranslation {
   convertDialectAttributes(Operation *op,
                            ArrayRef<llvm::Instruction *> instructions);
 
-  /// Translates parameter attributes and adds them to the returned AttrBuilder.
-  llvm::AttrBuilder convertParameterAttrs(DictionaryAttr paramAttrs);
+  FailureOr<llvm::AttrBuilder>
+  convertParameterAttrs(LLVMFuncOp func, int argIdx, DictionaryAttr paramAttrs);
 
   /// Original and translated module.
   Operation *mlirModule;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index aa49c4dc31fbc02..dc7816318131e41 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1077,6 +1077,34 @@ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
   return success();
 }
 
+LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
+                                                    unsigned regionIndex,
+                                                    unsigned argIndex,
+                                                    NamedAttribute argAttr) {
+  auto funcOp = dyn_cast<FunctionOpInterface>(op);
+  if (!funcOp)
+    return success();
+
+  bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
+  auto attrName = argAttr.getName();
+  if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
+    if (!isKernel)
+      return op->emitError()
+             << "'" << attrName
+             << "' attribute must be present only on kernel arguments.";
+    if (!llvm::isa<UnitAttr>(argAttr.getValue()))
+      return op->emitError()
+             << "'" << attrName << "' must be a unit attribute.";
+    if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName()))
+      return op->emitError()
+             << "'" << attrName
+             << "' attribute requires the argument to also have attribute '"
+             << LLVM::LLVMDialect::getByValAttrName() << "'.";
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM target attribute.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/AttrKindDetail.h b/mlir/lib/Target/LLVMIR/AttrKindDetail.h
index 7f81777886f56eb..55a364856bd6f99 100644
--- a/mlir/lib/Target/LLVMIR/AttrKindDetail.h
+++ b/mlir/lib/Target/LLVMIR/AttrKindDetail.h
@@ -59,6 +59,19 @@ getAttrKindToNameMapping() {
   return kindNamePairs;
 }
 
+static llvm::DenseMap<llvm::StringRef, llvm::Attribute::AttrKind>
+getAttrNameToKindMapping() {
+  static auto attrNameToKindMapping = []() {
+    static llvm::DenseMap<llvm::StringRef, llvm::Attribute::AttrKind>
+        nameKindMap;
+    for (auto kindNamePair : getAttrKindToNameMapping()) {
+      nameKindMap.insert({kindNamePair.second, kindNamePair.first});
+    }
+    return nameKindMap;
+  }();
+  return attrNameToKindMapping;
+}
+
 } // namespace detail
 } // namespace LLVM
 } // namespace mlir
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 45eb8402a7344f4..5e1712527d70151 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -201,6 +201,63 @@ class NVVMDialectLLVMIRTranslationInterface
     }
     return success();
   }
+
+  LogicalResult
+  convertParameterAttr(LLVMFuncOp funcOp, int argIdx, NamedAttribute attribute,
+                       LLVM::ModuleTranslation &moduleTranslation) const final {
+
+    llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
+    llvm::Function *llvmFunc =
+        moduleTranslation.lookupFunction(funcOp.getName());
+    auto nvvmAnnotations =
+        moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations");
+
+    if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) {
+      llvm::MDNode *gridConstantMetaData = nullptr;
+
+      // Check if a 'grid_constant' metadata node exists for the given function
+      for (int i = nvvmAnnotations->getNumOperands() - 1; i >= 0; --i) {
+        auto opnd = nvvmAnnotations->getOperand(i);
+        if (opnd->getNumOperands() == 3 &&
+            opnd->getOperand(0) == llvm::ValueAsMetadata::get(llvmFunc) &&
+            opnd->getOperand(1) ==
+                llvm::MDString::get(llvmContext, "grid_constant")) {
+          gridConstantMetaData = opnd;
+          break;
+        }
+      }
+
+      // 'grid_constant' is a function-level meta data node with a list of
+      // integers, where each integer n denotes that the nth parameter has the
+      // grid_constant annotation (numbering from 1). This requires aggregating
+      // the indices of the individual parameters that have this attribute.
+      llvm::Type *i32 = llvm::IntegerType::get(llvmContext, 32);
+      if (gridConstantMetaData == nullptr) {
+        // Create a new 'grid_constant' metadata node
+        SmallVector<llvm::Metadata *> gridConstMetadata = {
+            llvm::ValueAsMetadata::getConstant(
+                llvm::ConstantInt::get(i32, argIdx + 1))};
+        llvm::Metadata *llvmMetadata[] = {
+            llvm::ValueAsMetadata::get(llvmFunc),
+            llvm::MDString::get(llvmContext, "grid_constant"),
+            llvm::MDNode::get(llvmContext, gridConstMetadata)};
+        llvm::MDNode *llvmMetadataNode =
+            llvm::MDNode::get(llvmContext, llvmMetadata);
+        nvvmAnnotations->addOperand(llvmMetadataNode);
+      } else {
+        // Append argIdx + 1 to the 'grid_constant' argument list
+        if (auto argList =
+                dyn_cast<llvm::MDTuple>(gridConstantMetaData->getOperand(2))) {
+          auto clonedArgList = argList->clone();
+          clonedArgList->push_back((llvm::ValueAsMetadata::getConstant(
+              llvm::ConstantInt::get(i32, argIdx + 1))));
+          gridConstantMetaData->replaceOperandWith(
+              2, llvm::MDNode::replaceWithUniqued(std::move(clonedArgList)));
+        }
+      }
+    }
+    return success();
+  }
 };
 } // namespace
 
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 2763a0fdd62aba1..574dbfa177b9bb3 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1174,28 +1174,29 @@ static void convertFunctionAttributes(LLVMFuncOp func,
   llvmFunc->setMemoryEffects(newMemEffects);
 }
 
-llvm::AttrBuilder
-ModuleTranslation::convertParameterAttrs(DictionaryAttr paramAttrs) {
+FailureOr<llvm::AttrBuilder>
+ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
+                                         DictionaryAttr paramAttrs) {
   llvm::AttrBuilder attrBuilder(llvmModule->getContext());
-
-  for (auto [llvmKind, mlirName] : getAttrKindToNameMapping()) {
-    Attribute attr = paramAttrs.get(mlirName);
-    // Skip attributes that are not present.
-    if (!attr)
-      continue;
-
-    // NOTE: C++17 does not support capturing structured bindings.
-    llvm::Attribute::AttrKind llvmKindCap = llvmKind;
-
-    llvm::TypeSwitch<Attribute>(attr)
-        .Case<TypeAttr>([&](auto typeAttr) {
-          attrBuilder.addTypeAttr(llvmKindCap,
-                                  convertType(typeAttr.getValue()));
-        })
-        .Case<IntegerAttr>([&](auto intAttr) {
-          attrBuilder.addRawIntAttr(llvmKindCap, intAttr.getInt());
-        })
-        .Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKindCap); });
+  auto attrNameToKindMapping = getAttrNameToKindMapping();
+
+  for (auto namedAttr : paramAttrs) {
+    auto it = attrNameToKindMapping.find(namedAttr.getName());
+    if (it != attrNameToKindMapping.end()) {
+      llvm::Attribute::AttrKind llvmKind = it->second;
+
+      llvm::TypeSwitch<Attribute>(namedAttr.getValue())
+          .Case<TypeAttr>([&](auto typeAttr) {
+            attrBuilder.addTypeAttr(llvmKind, convertType(typeAttr.getValue()));
+          })
+          .Case<IntegerAttr>([&](auto intAttr) {
+            attrBuilder.addRawIntAttr(llvmKind, intAttr.getInt());
+          })
+          .Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKind); });
+    } else if (namedAttr.getNameDialect()) {
+      if (failed(iface.convertParameterAttr(func, argIdx, namedAttr, *this)))
+        return failure();
+    }
   }
 
   return attrBuilder;
@@ -1224,14 +1225,21 @@ LogicalResult ModuleTranslation::convertFunctionSignatures() {
     // Convert result attributes.
     if (ArrayAttr allResultAttrs = function.getAllResultAttrs()) {
       DictionaryAttr resultAttrs = cast<DictionaryAttr>(allResultAttrs[0]);
-      llvmFunc->addRetAttrs(convertParameterAttrs(resultAttrs));
+      FailureOr<llvm::AttrBuilder> attrBuilder =
+          convertParameterAttrs(function, -1, resultAttrs);
+      if (failed(attrBuilder))
+        return failure();
+      llvmFunc->addRetAttrs(*attrBuilder);
     }
 
     // Convert argument attributes.
     for (auto [argIdx, llvmArg] : llvm::enumerate(llvmFunc->args())) {
       if (DictionaryAttr argAttrs = function.getArgAttrDict(argIdx)) {
-        llvm::AttrBuilder attrBuilder = convertParameterAttrs(argAttrs);
-        llvmArg.addAttrs(attrBuilder);
+        FailureOr<llvm::AttrBuilder> attrBuilder =
+            convertParameterAttrs(function, argIdx, argAttrs);
+        if (failed(attrBuilder))
+          return failure();
+        llvmArg.addAttrs(*attrBuilder);
       }
     }
 
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index ce483ddab22a0ee..0369f45ca6a0156 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -472,3 +472,29 @@ gpu.module @module_1 [#nvvm.target<chip = "sm_90", features = "+ptx70", link = [
 
 gpu.module @module_2 [#nvvm.target<chip = "sm_90">, #nvvm.target<chip = "sm_80">, #nvvm.target<chip = "sm_70">] {
 }
+
+// CHECK-LABEL : nvvm.grid_constant
+llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) attributes {nvvm.kernel} {
+  llvm.return
+}
+
+// -----
+
+// expected-error @below {{'"nvvm.grid_constant"' attribute must be present only on kernel arguments}}
+llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) {
+  llvm.return
+}
+
+// -----
+
+// expected-error @below {{'"nvvm.grid_constant"' attribute requires the argument to also have attribute 'llvm.byval'}}
+llvm.func @kernel_func(%arg0: !llvm.ptr {nvvm.grid_constant}) attributes {nvvm.kernel} {
+  llvm.return
+}
+
+// -----
+
+// expected-error @below {{'"nvvm.grid_constant"' must be a unit attribute}}
+llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant = true}) attributes {nvvm.kernel} {
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 8c5e3524a848f68..6dc47d08fc5c812 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -518,3 +518,20 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = array<i32: 3, 4
   llvm.return
 }
 
+// -----
+// CHECK: !nvvm.annotations =
+// CHECK: !1 = !{ptr @kernel_func, !"grid_constant", !2}
+// CHECK: !2 = !{i32 1}
+// CHECK: !3 = !{ptr @kernel_func, !"kernel", i32 1}
+llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) attributes {nvvm.kernel} {
+  llvm.return
+}
+
+// -----
+// CHECK: !nvvm.annotations =
+// CHECK: !1 = !{ptr @kernel_func, !"grid_constant", !2}
+// CHECK: !2 = !{i32 1, i32 3}
+// CHECK: !3 = !{ptr @kernel_func, !"kernel", i32 1}
+llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}, %arg1: f32, %arg2: !llvm.ptr {llvm.byval = f32, nvvm.grid_constant}) attributes {nvvm.kernel} {
+  llvm.return
+}

@llvmbot
Copy link
Member

llvmbot commented Jan 16, 2024

@llvm/pr-subscribers-mlir-llvm

Author: Rishi Surendran (rishisurendran)

Changes

Add support for attribute nvvm.grid_constant on LLVM function arguments. The attribute can be attached only to arguments of type llvm.ptr that have llvm.byval attribute.
Generate LLVM metadata for functions with nvvm.grid_constant arguments. The metadata node is a list of integers, where each integer n denotes that the nth parameter has the
grid_constant annotation (numbering from 1).

This patch also adds convertParameterAttr to LLVMTranslationDialectInterface for supporting the translation of derived dialect attributes on function parameters 


Full diff: https://github.com/llvm/llvm-project/pull/78228.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+13)
  • (modified) mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h (+26)
  • (modified) mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h (+2-2)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+28)
  • (modified) mlir/lib/Target/LLVMIR/AttrKindDetail.h (+13)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp (+57)
  • (modified) mlir/lib/Target/LLVMIR/ModuleTranslation.cpp (+32-24)
  • (modified) mlir/test/Dialect/LLVMIR/nvvm.mlir (+26)
  • (modified) mlir/test/Target/LLVMIR/nvvmir.mlir (+17)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 7140e614412f98..1fc5ee2c32bd49 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -59,6 +59,19 @@ def NVVM_Dialect : Dialect {
     /// Get the name of the attribute used to annotate max number of
     /// registers that can be allocated per thread.
     static StringRef getMaxnregAttrName() { return "nvvm.maxnreg"; }
+
+    /// Get the name of the attribute used to annotate kernel arguments that
+    /// are grid constants.
+    static StringRef getGridConstantAttrName() { return "nvvm.grid_constant"; }
+
+    /// Verify an attribute from this dialect on the argument at 'argIndex' for
+    /// the region at 'regionIndex' on the given operation. Returns failure if
+    /// the verification failed, success otherwise. This hook may optionally be
+    /// invoked from any operation containing a region.
+    LogicalResult verifyRegionArgAttribute(Operation *,
+                                           unsigned regionIndex,
+                                           unsigned argIndex,
+                                           NamedAttribute) override;
   }];
 
   let useDefaultAttributePrinterParser = 1;
diff --git a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
index 19991a6f89d80f..55358ebc6e86ef 100644
--- a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
+++ b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_TARGET_LLVMIR_LLVMTRANSLATIONINTERFACE_H
 #define MLIR_TARGET_LLVMIR_LLVMTRANSLATIONINTERFACE_H
 
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/DialectInterface.h"
 #include "mlir/Support/LogicalResult.h"
@@ -25,6 +26,7 @@ class IRBuilderBase;
 namespace mlir {
 namespace LLVM {
 class ModuleTranslation;
+class LLVMFuncOp;
 } // namespace LLVM
 
 /// Base class for dialect interfaces providing translation to LLVM IR.
@@ -58,6 +60,16 @@ class LLVMTranslationDialectInterface
                  LLVM::ModuleTranslation &moduleTranslation) const {
     return success();
   }
+
+  /// Hook for derived dialect interface to translate or act on a derived
+  /// dialect attribute that appears on a function parameter. This gets called
+  /// after the function operation has been translated.
+  virtual LogicalResult
+  convertParameterAttr(LLVM::LLVMFuncOp function, int argIdx,
+                       NamedAttribute attr,
+                       LLVM::ModuleTranslation &moduleTranslation) const {
+    return success();
+  }
 };
 
 /// Interface collection for translation to LLVM IR, dispatches to a concrete
@@ -90,6 +102,20 @@ class LLVMTranslationInterface
     }
     return success();
   }
+
+  /// Acts on the given function operation using the interface implemented by
+  /// the dialect of one of the function parameter attributes.
+  virtual LogicalResult
+  convertParameterAttr(LLVM::LLVMFuncOp function, int argIdx,
+                       NamedAttribute attribute,
+                       LLVM::ModuleTranslation &moduleTranslation) const {
+    if (const LLVMTranslationDialectInterface *iface =
+            getInterfaceFor(attribute.getNameDialect())) {
+      return iface->convertParameterAttr(function, argIdx, attribute,
+                                         moduleTranslation);
+    }
+    return success();
+  }
 };
 
 } // namespace mlir
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index d6b03aca28d24d..f0012bf875511e 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -326,8 +326,8 @@ class ModuleTranslation {
   convertDialectAttributes(Operation *op,
                            ArrayRef<llvm::Instruction *> instructions);
 
-  /// Translates parameter attributes and adds them to the returned AttrBuilder.
-  llvm::AttrBuilder convertParameterAttrs(DictionaryAttr paramAttrs);
+  FailureOr<llvm::AttrBuilder>
+  convertParameterAttrs(LLVMFuncOp func, int argIdx, DictionaryAttr paramAttrs);
 
   /// Original and translated module.
   Operation *mlirModule;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index aa49c4dc31fbc0..dc7816318131e4 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1077,6 +1077,34 @@ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
   return success();
 }
 
+LogicalResult NVVMDialect::verifyRegionArgAttribute(Operation *op,
+                                                    unsigned regionIndex,
+                                                    unsigned argIndex,
+                                                    NamedAttribute argAttr) {
+  auto funcOp = dyn_cast<FunctionOpInterface>(op);
+  if (!funcOp)
+    return success();
+
+  bool isKernel = op->hasAttr(NVVMDialect::getKernelFuncAttrName());
+  auto attrName = argAttr.getName();
+  if (attrName == NVVM::NVVMDialect::getGridConstantAttrName()) {
+    if (!isKernel)
+      return op->emitError()
+             << "'" << attrName
+             << "' attribute must be present only on kernel arguments.";
+    if (!llvm::isa<UnitAttr>(argAttr.getValue()))
+      return op->emitError()
+             << "'" << attrName << "' must be a unit attribute.";
+    if (!funcOp.getArgAttr(argIndex, LLVM::LLVMDialect::getByValAttrName()))
+      return op->emitError()
+             << "'" << attrName
+             << "' attribute requires the argument to also have attribute '"
+             << LLVM::LLVMDialect::getByValAttrName() << "'.";
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM target attribute.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/AttrKindDetail.h b/mlir/lib/Target/LLVMIR/AttrKindDetail.h
index 7f81777886f56e..55a364856bd6f9 100644
--- a/mlir/lib/Target/LLVMIR/AttrKindDetail.h
+++ b/mlir/lib/Target/LLVMIR/AttrKindDetail.h
@@ -59,6 +59,19 @@ getAttrKindToNameMapping() {
   return kindNamePairs;
 }
 
+static llvm::DenseMap<llvm::StringRef, llvm::Attribute::AttrKind>
+getAttrNameToKindMapping() {
+  static auto attrNameToKindMapping = []() {
+    static llvm::DenseMap<llvm::StringRef, llvm::Attribute::AttrKind>
+        nameKindMap;
+    for (auto kindNamePair : getAttrKindToNameMapping()) {
+      nameKindMap.insert({kindNamePair.second, kindNamePair.first});
+    }
+    return nameKindMap;
+  }();
+  return attrNameToKindMapping;
+}
+
 } // namespace detail
 } // namespace LLVM
 } // namespace mlir
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 45eb8402a7344f..5e1712527d7015 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -201,6 +201,63 @@ class NVVMDialectLLVMIRTranslationInterface
     }
     return success();
   }
+
+  LogicalResult
+  convertParameterAttr(LLVMFuncOp funcOp, int argIdx, NamedAttribute attribute,
+                       LLVM::ModuleTranslation &moduleTranslation) const final {
+
+    llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
+    llvm::Function *llvmFunc =
+        moduleTranslation.lookupFunction(funcOp.getName());
+    auto nvvmAnnotations =
+        moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations");
+
+    if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) {
+      llvm::MDNode *gridConstantMetaData = nullptr;
+
+      // Check if a 'grid_constant' metadata node exists for the given function
+      for (int i = nvvmAnnotations->getNumOperands() - 1; i >= 0; --i) {
+        auto opnd = nvvmAnnotations->getOperand(i);
+        if (opnd->getNumOperands() == 3 &&
+            opnd->getOperand(0) == llvm::ValueAsMetadata::get(llvmFunc) &&
+            opnd->getOperand(1) ==
+                llvm::MDString::get(llvmContext, "grid_constant")) {
+          gridConstantMetaData = opnd;
+          break;
+        }
+      }
+
+      // 'grid_constant' is a function-level meta data node with a list of
+      // integers, where each integer n denotes that the nth parameter has the
+      // grid_constant annotation (numbering from 1). This requires aggregating
+      // the indices of the individual parameters that have this attribute.
+      llvm::Type *i32 = llvm::IntegerType::get(llvmContext, 32);
+      if (gridConstantMetaData == nullptr) {
+        // Create a new 'grid_constant' metadata node
+        SmallVector<llvm::Metadata *> gridConstMetadata = {
+            llvm::ValueAsMetadata::getConstant(
+                llvm::ConstantInt::get(i32, argIdx + 1))};
+        llvm::Metadata *llvmMetadata[] = {
+            llvm::ValueAsMetadata::get(llvmFunc),
+            llvm::MDString::get(llvmContext, "grid_constant"),
+            llvm::MDNode::get(llvmContext, gridConstMetadata)};
+        llvm::MDNode *llvmMetadataNode =
+            llvm::MDNode::get(llvmContext, llvmMetadata);
+        nvvmAnnotations->addOperand(llvmMetadataNode);
+      } else {
+        // Append argIdx + 1 to the 'grid_constant' argument list
+        if (auto argList =
+                dyn_cast<llvm::MDTuple>(gridConstantMetaData->getOperand(2))) {
+          auto clonedArgList = argList->clone();
+          clonedArgList->push_back((llvm::ValueAsMetadata::getConstant(
+              llvm::ConstantInt::get(i32, argIdx + 1))));
+          gridConstantMetaData->replaceOperandWith(
+              2, llvm::MDNode::replaceWithUniqued(std::move(clonedArgList)));
+        }
+      }
+    }
+    return success();
+  }
 };
 } // namespace
 
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 2763a0fdd62aba..574dbfa177b9bb 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1174,28 +1174,29 @@ static void convertFunctionAttributes(LLVMFuncOp func,
   llvmFunc->setMemoryEffects(newMemEffects);
 }
 
-llvm::AttrBuilder
-ModuleTranslation::convertParameterAttrs(DictionaryAttr paramAttrs) {
+FailureOr<llvm::AttrBuilder>
+ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
+                                         DictionaryAttr paramAttrs) {
   llvm::AttrBuilder attrBuilder(llvmModule->getContext());
-
-  for (auto [llvmKind, mlirName] : getAttrKindToNameMapping()) {
-    Attribute attr = paramAttrs.get(mlirName);
-    // Skip attributes that are not present.
-    if (!attr)
-      continue;
-
-    // NOTE: C++17 does not support capturing structured bindings.
-    llvm::Attribute::AttrKind llvmKindCap = llvmKind;
-
-    llvm::TypeSwitch<Attribute>(attr)
-        .Case<TypeAttr>([&](auto typeAttr) {
-          attrBuilder.addTypeAttr(llvmKindCap,
-                                  convertType(typeAttr.getValue()));
-        })
-        .Case<IntegerAttr>([&](auto intAttr) {
-          attrBuilder.addRawIntAttr(llvmKindCap, intAttr.getInt());
-        })
-        .Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKindCap); });
+  auto attrNameToKindMapping = getAttrNameToKindMapping();
+
+  for (auto namedAttr : paramAttrs) {
+    auto it = attrNameToKindMapping.find(namedAttr.getName());
+    if (it != attrNameToKindMapping.end()) {
+      llvm::Attribute::AttrKind llvmKind = it->second;
+
+      llvm::TypeSwitch<Attribute>(namedAttr.getValue())
+          .Case<TypeAttr>([&](auto typeAttr) {
+            attrBuilder.addTypeAttr(llvmKind, convertType(typeAttr.getValue()));
+          })
+          .Case<IntegerAttr>([&](auto intAttr) {
+            attrBuilder.addRawIntAttr(llvmKind, intAttr.getInt());
+          })
+          .Case<UnitAttr>([&](auto) { attrBuilder.addAttribute(llvmKind); });
+    } else if (namedAttr.getNameDialect()) {
+      if (failed(iface.convertParameterAttr(func, argIdx, namedAttr, *this)))
+        return failure();
+    }
   }
 
   return attrBuilder;
@@ -1224,14 +1225,21 @@ LogicalResult ModuleTranslation::convertFunctionSignatures() {
     // Convert result attributes.
     if (ArrayAttr allResultAttrs = function.getAllResultAttrs()) {
       DictionaryAttr resultAttrs = cast<DictionaryAttr>(allResultAttrs[0]);
-      llvmFunc->addRetAttrs(convertParameterAttrs(resultAttrs));
+      FailureOr<llvm::AttrBuilder> attrBuilder =
+          convertParameterAttrs(function, -1, resultAttrs);
+      if (failed(attrBuilder))
+        return failure();
+      llvmFunc->addRetAttrs(*attrBuilder);
     }
 
     // Convert argument attributes.
     for (auto [argIdx, llvmArg] : llvm::enumerate(llvmFunc->args())) {
       if (DictionaryAttr argAttrs = function.getArgAttrDict(argIdx)) {
-        llvm::AttrBuilder attrBuilder = convertParameterAttrs(argAttrs);
-        llvmArg.addAttrs(attrBuilder);
+        FailureOr<llvm::AttrBuilder> attrBuilder =
+            convertParameterAttrs(function, argIdx, argAttrs);
+        if (failed(attrBuilder))
+          return failure();
+        llvmArg.addAttrs(*attrBuilder);
       }
     }
 
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index ce483ddab22a0e..0369f45ca6a015 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -472,3 +472,29 @@ gpu.module @module_1 [#nvvm.target<chip = "sm_90", features = "+ptx70", link = [
 
 gpu.module @module_2 [#nvvm.target<chip = "sm_90">, #nvvm.target<chip = "sm_80">, #nvvm.target<chip = "sm_70">] {
 }
+
+// CHECK-LABEL : nvvm.grid_constant
+llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) attributes {nvvm.kernel} {
+  llvm.return
+}
+
+// -----
+
+// expected-error @below {{'"nvvm.grid_constant"' attribute must be present only on kernel arguments}}
+llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) {
+  llvm.return
+}
+
+// -----
+
+// expected-error @below {{'"nvvm.grid_constant"' attribute requires the argument to also have attribute 'llvm.byval'}}
+llvm.func @kernel_func(%arg0: !llvm.ptr {nvvm.grid_constant}) attributes {nvvm.kernel} {
+  llvm.return
+}
+
+// -----
+
+// expected-error @below {{'"nvvm.grid_constant"' must be a unit attribute}}
+llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant = true}) attributes {nvvm.kernel} {
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 8c5e3524a848f6..6dc47d08fc5c81 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -518,3 +518,20 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = array<i32: 3, 4
   llvm.return
 }
 
+// -----
+// CHECK: !nvvm.annotations =
+// CHECK: !1 = !{ptr @kernel_func, !"grid_constant", !2}
+// CHECK: !2 = !{i32 1}
+// CHECK: !3 = !{ptr @kernel_func, !"kernel", i32 1}
+llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}) attributes {nvvm.kernel} {
+  llvm.return
+}
+
+// -----
+// CHECK: !nvvm.annotations =
+// CHECK: !1 = !{ptr @kernel_func, !"grid_constant", !2}
+// CHECK: !2 = !{i32 1, i32 3}
+// CHECK: !3 = !{ptr @kernel_func, !"kernel", i32 1}
+llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant}, %arg1: f32, %arg2: !llvm.ptr {llvm.byval = f32, nvvm.grid_constant}) attributes {nvvm.kernel} {
+  llvm.return
+}

return iface->convertParameterAttr(function, argIdx, attribute,
moduleTranslation);
}
return success();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure the default here should be success. This would be dropping any attribute from a dialect without interface on the floor without warning.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will modify it to return failure. I followed what amendOperation was doing for op attributes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Returning failure causes several test failures. There are dialect attributes like 'fir.bindc_name' which doesn't require any handling here. I updated it to emit a warning.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(unresolving, @ftynse should take another look here and acknowledge the solution explicitly)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, the analogy is fair enough.

/// dialect attribute that appears on a function parameter. This gets called
/// after the function operation has been translated.
virtual LogicalResult
convertParameterAttr(LLVM::LLVMFuncOp function, int argIdx,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most of LLVM uses unsigned for indexes.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are you referring to?
I strongly believe in using int in absence of bit/mask operation.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

$ git grep " int " | wc -l
1161
$ git grep " unsigned " | wc -l
3449

also most of the former is in tests or tools.

That being said, I'm not a proponent of using unsigned (I'd rather use int32/64_t throughout) , but I am a proponent of consistency.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do you suggest to fix this: I see it as a bug and I am consistently using signed int everywhere!
Should I sent patched updating every file for consistency or should we make sure new code use safer arithmetic patterns?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does not seems like we'll be able to make anything a LLVM-wide policy here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MLIR side we could. I'm rather pro not making it worse though (e.g., I'd consider consistency when it's towards direction we'd want to go rather than where we ended up in undesirable position and keeping going down that route)

@ftynse ftynse requested review from gysit and Dinistro January 16, 2024 09:41
Copy link
Member

@grypp grypp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the contribution. I am not clear how it is used grid_constant annotation? Is there PR in LLVM?


/// Get the name of the attribute used to annotate kernel arguments that
/// are grid constants.
static StringRef getGridConstantAttrName() { return "nvvm.grid_constant"; }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an attribute for a kernel parameter, while other attributes are for the kernel itself. We might need to split them later

@rishisurendran
Copy link
Contributor Author

Thanks for the contribution. I am not clear how it is used grid_constant annotation? Is there PR in LLVM?

There are no changes in LLVM. I have updated the description. The generated metadata will be handled by libNVVM.

@grypp
Copy link
Member

grypp commented Jan 17, 2024

Thanks for the contribution. I am not clear how it is used grid_constant annotation? Is there PR in LLVM?

There are no changes in LLVM. I have updated the description. The generated metadata will be handled by libNVVM.

Thanks for the explanation. It is useful to have this is supported by libNVVM, and frankly I am okay with this PR. But it would be nice to agree the attribute with open-source llvm folks.

@joker-eph
Copy link
Collaborator

This is something that LLVM will support eventually as well: it is part of the NVVM specification and it is needed to support the grid_constant CUDA C++ feature.

@rishisurendran
Copy link
Contributor Author

Thanks for the contribution. I am not clear how it is used grid_constant annotation? Is there PR in LLVM?

There are no changes in LLVM. I have updated the description. The generated metadata will be handled by libNVVM.

Thanks for the explanation. It is useful to have this is supported by libNVVM, and frankly I am okay with this PR. But it would be nice to agree the attribute with open-source llvm folks.

We're currently focusing on adding support for MLIR to cover the NVVM IR specification, of course LLVM support should also be added. We haven’t settled on a roadmap for this on our side just yet though, maybe others will pick this up in the meantime.

@rishisurendran
Copy link
Contributor Author

@ftynse @grypp I have addressed the feedback. Please review.

Copy link
Member

@grypp grypp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks good to me even though open-source llvm does not have this metadata. But I would wait @ftynse to review once again

Copy link
Member

@ftynse ftynse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for me.

return iface->convertParameterAttr(function, argIdx, attribute,
moduleTranslation);
}
return success();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, the analogy is fair enough.

@joker-eph joker-eph merged commit fa6850a into llvm:main Feb 12, 2024
vzakhari added a commit to vzakhari/llvm-project that referenced this pull request Feb 13, 2024
Register the LLVM IR translation interface for FIR to avoid
warnings about "Unhandled parameter attribute" after llvm#78228.
vzakhari added a commit that referenced this pull request Feb 13, 2024
Register the LLVM IR translation interface for FIR to avoid
warnings about "Unhandled parameter attribute" after #78228.
copybara-service bot pushed a commit to jax-ml/jax-triton that referenced this pull request Feb 15, 2024
This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228
It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages.

PiperOrigin-RevId: 607293980
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Feb 15, 2024
This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228
It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages.

PiperOrigin-RevId: 607293980
copybara-service bot pushed a commit to jax-ml/jax-triton that referenced this pull request Feb 15, 2024
This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228
It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages.

PiperOrigin-RevId: 607293980
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Feb 15, 2024
This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228
It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages.

PiperOrigin-RevId: 607293980
copybara-service bot pushed a commit to jax-ml/jax-triton that referenced this pull request Feb 15, 2024
This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228
It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages.

PiperOrigin-RevId: 607293980
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Feb 15, 2024
This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228
It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages.

PiperOrigin-RevId: 607293980
copybara-service bot pushed a commit to jax-ml/jax-triton that referenced this pull request Feb 15, 2024
This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228
It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages.

PiperOrigin-RevId: 607293980
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Feb 15, 2024
This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228
It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages.

PiperOrigin-RevId: 607293980
copybara-service bot pushed a commit to jax-ml/jax-triton that referenced this pull request Feb 15, 2024
This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228
It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages.

PiperOrigin-RevId: 607348584
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Feb 15, 2024
This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228
It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages.

PiperOrigin-RevId: 607348584
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Feb 15, 2024
This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228
It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages.

PiperOrigin-RevId: 607348584
gflegar added a commit to openxla/triton that referenced this pull request Feb 16, 2024
This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228
It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages.
gflegar added a commit to openxla/triton that referenced this pull request Feb 16, 2024
This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228
It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages.
gflegar added a commit to openxla/triton that referenced this pull request Feb 16, 2024
This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228
It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages.
gflegar added a commit to openxla/triton that referenced this pull request Feb 19, 2024
This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228
It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages.
ThomasRaoux pushed a commit to triton-lang/triton that referenced this pull request Feb 19, 2024
… is set (#3147)

This prevents crashes in test_core.py due to too many diagnostics
emitted in llvm/llvm-project#78228 It should
also speed up compile times, as we can use multithreading, and avoid
handling diagnostic messages.
Jokeren pushed a commit to triton-lang/triton that referenced this pull request Feb 19, 2024
… is set (#3147)

This prevents crashes in test_core.py due to too many diagnostics
emitted in llvm/llvm-project#78228 It should
also speed up compile times, as we can use multithreading, and avoid
handling diagnostic messages.
rahulbatra85 pushed a commit to ROCm/xla that referenced this pull request Mar 1, 2024
This prevents crashes in test_core.py due to too many diagnostics emitted in llvm/llvm-project#78228
It should also speed up compile times, as we can use multithreading, and avoid handling diagnostic messages.

PiperOrigin-RevId: 607348584
binarman pushed a commit to binarman/triton that referenced this pull request Apr 2, 2024
… is set (triton-lang#3147)

This prevents crashes in test_core.py due to too many diagnostics
emitted in llvm/llvm-project#78228 It should
also speed up compile times, as we can use multithreading, and avoid
handling diagnostic messages.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants