Skip to content

[mlir][llvm] Port overflowFlags to a native operation property #89312

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 18, 2024

Conversation

Mogball
Copy link
Contributor

@Mogball Mogball commented Apr 18, 2024

This PR changes the LLVM dialect's IntegerOverflowFlags to be stored on operations as native properties.

Mogball added 2 commits April 18, 2024 20:26
This is useful for defining operation properties that are enums.
This PR changes the LLVM dialect's IntegerOverflowFlags to be stored on
operations as native properties.
@Mogball Mogball requested review from jpienaar and ftynse April 18, 2024 21:10
@Mogball Mogball requested a review from joker-eph April 18, 2024 21:10
@llvmbot
Copy link
Member

llvmbot commented Apr 18, 2024

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir-ods

Author: Jeff Niu (Mogball)

Changes

This PR changes the LLVM dialect's IntegerOverflowFlags to be stored on operations as native properties.


Patch is 21.95 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/89312.diff

12 Files Affected:

  • (modified) mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h (+11-11)
  • (modified) mlir/include/mlir/Conversion/LLVMCommon/Pattern.h (+9-5)
  • (modified) mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h (+10-6)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td (+29-47)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td (+4-4)
  • (modified) mlir/include/mlir/IR/Properties.td (+13)
  • (modified) mlir/include/mlir/Target/LLVMIR/ModuleImport.h (+1-2)
  • (modified) mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp (-7)
  • (modified) mlir/lib/Conversion/LLVMCommon/Pattern.cpp (+13-6)
  • (modified) mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp (+13-13)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+72-4)
  • (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+3-4)
diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
index 0891e2ba7be760..7ffc8613317603 100644
--- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
+++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
@@ -31,11 +31,6 @@ convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr);
 LLVM::IntegerOverflowFlags
 convertArithOverflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags);
 
-/// Creates an LLVM overflow attribute from a given arithmetic overflow
-/// attribute.
-LLVM::IntegerOverflowFlagsAttr
-convertArithOverflowAttrToLLVM(arith::IntegerOverflowFlagsAttr flagsAttr);
-
 /// Creates an LLVM rounding mode enum value from a given arithmetic rounding
 /// mode enum value.
 LLVM::RoundingMode
@@ -72,6 +67,9 @@ class AttrConvertFastMathToLLVM {
   }
 
   ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
+  LLVM::IntegerOverflowFlags getOverflowFlags() const {
+    return LLVM::IntegerOverflowFlags::none;
+  }
 
 private:
   NamedAttrList convertedAttr;
@@ -89,19 +87,18 @@ class AttrConvertOverflowToLLVM {
     // Get the name of the arith overflow attribute.
     StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName();
     // Remove the source overflow attribute.
-    auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>(
-        convertedAttr.erase(arithAttrName));
-    if (arithAttr) {
-      StringRef targetAttrName = TargetOp::getIntegerOverflowAttrName();
-      convertedAttr.set(targetAttrName,
-                        convertArithOverflowAttrToLLVM(arithAttr));
+    if (auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>(
+            convertedAttr.erase(arithAttrName))) {
+      overflowFlags = convertArithOverflowFlagsToLLVM(arithAttr.getValue());
     }
   }
 
   ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
+  LLVM::IntegerOverflowFlags getOverflowFlags() const { return overflowFlags; }
 
 private:
   NamedAttrList convertedAttr;
+  LLVM::IntegerOverflowFlags overflowFlags = LLVM::IntegerOverflowFlags::none;
 };
 
 template <typename SourceOp, typename TargetOp>
@@ -132,6 +129,9 @@ class AttrConverterConstrainedFPToLLVM {
   }
 
   ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
+  LLVM::IntegerOverflowFlags getOverflowFlags() const {
+    return LLVM::IntegerOverflowFlags::none;
+  }
 
 private:
   NamedAttrList convertedAttr;
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index f362167ee93249..f3bf5b66398e09 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -11,6 +11,7 @@
 
 #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
@@ -18,13 +19,16 @@ class CallOpInterface;
 
 namespace LLVM {
 namespace detail {
+/// Handle generically setting flags as native properties on LLVM operations.
+void setNativeProperties(Operation *op, IntegerOverflowFlags overflowFlags);
+
 /// Replaces the given operation "op" with a new operation of type "targetOp"
 /// and given operands.
-LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp,
-                              ValueRange operands,
-                              ArrayRef<NamedAttribute> targetAttrs,
-                              const LLVMTypeConverter &typeConverter,
-                              ConversionPatternRewriter &rewriter);
+LogicalResult oneToOneRewrite(
+    Operation *op, StringRef targetOp, ValueRange operands,
+    ArrayRef<NamedAttribute> targetAttrs,
+    const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
+    IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
 
 } // namespace detail
 } // namespace LLVM
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index 279175b6128fc7..964281592cc65e 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -54,11 +54,11 @@ LogicalResult handleMultidimensionalVectors(
     std::function<Value(Type, ValueRange)> createOperand,
     ConversionPatternRewriter &rewriter);
 
-LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
-                                    ValueRange operands,
-                                    ArrayRef<NamedAttribute> targetAttrs,
-                                    const LLVMTypeConverter &typeConverter,
-                                    ConversionPatternRewriter &rewriter);
+LogicalResult vectorOneToOneRewrite(
+    Operation *op, StringRef targetOp, ValueRange operands,
+    ArrayRef<NamedAttribute> targetAttrs,
+    const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
+    IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
 } // namespace detail
 } // namespace LLVM
 
@@ -70,6 +70,9 @@ class AttrConvertPassThrough {
   AttrConvertPassThrough(SourceOp srcOp) : srcAttrs(srcOp->getAttrs()) {}
 
   ArrayRef<NamedAttribute> getAttrs() const { return srcAttrs; }
+  LLVM::IntegerOverflowFlags getOverflowFlags() const {
+    return LLVM::IntegerOverflowFlags::none;
+  }
 
 private:
   ArrayRef<NamedAttribute> srcAttrs;
@@ -100,7 +103,8 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
 
     return LLVM::detail::vectorOneToOneRewrite(
         op, TargetOp::getOperationName(), adaptor.getOperands(),
-        attrConvert.getAttrs(), *this->getTypeConverter(), rewriter);
+        attrConvert.getAttrs(), *this->getTypeConverter(), rewriter,
+        attrConvert.getOverflowFlags());
   }
 };
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index cee752aeb269b7..7085f81e203a1e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -50,58 +50,40 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
 
 def IntegerOverflowFlagsInterface : OpInterface<"IntegerOverflowFlagsInterface"> {
   let description = [{
-    Access to op integer overflow flags.
+    This interface defines an LLVM operation with integer overflow flags and
+    provides a uniform API for accessing them.
   }];
 
   let cppNamespace = "::mlir::LLVM";
 
   let methods = [
-    InterfaceMethod<
-      /*desc=*/        "Returns an IntegerOverflowFlagsAttr attribute for the operation",
-      /*returnType=*/  "IntegerOverflowFlagsAttr",
-      /*methodName=*/  "getOverflowAttr",
-      /*args=*/        (ins),
-      /*methodBody=*/  [{}],
-      /*defaultImpl=*/ [{
-        auto op = cast<ConcreteOp>(this->getOperation());
-        return op.getOverflowFlagsAttr();
-      }]
-      >,
-    InterfaceMethod<
-      /*desc=*/        "Returns whether the operation has the No Unsigned Wrap keyword",
-      /*returnType=*/  "bool",
-      /*methodName=*/  "hasNoUnsignedWrap",
-      /*args=*/        (ins),
-      /*methodBody=*/  [{}],
-      /*defaultImpl=*/ [{
-        auto op = cast<ConcreteOp>(this->getOperation());
-        IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
-        return bitEnumContainsAll(flags, IntegerOverflowFlags::nuw);
-      }]
-      >,
-    InterfaceMethod<
-      /*desc=*/        "Returns whether the operation has the No Signed Wrap keyword",
-      /*returnType=*/  "bool",
-      /*methodName=*/  "hasNoSignedWrap",
-      /*args=*/        (ins),
-      /*methodBody=*/  [{}],
-      /*defaultImpl=*/ [{
-        auto op = cast<ConcreteOp>(this->getOperation());
-        IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
-        return bitEnumContainsAll(flags, IntegerOverflowFlags::nsw);
-      }]
-      >,
-    StaticInterfaceMethod<
-      /*desc=*/        [{Returns the name of the IntegerOverflowFlagsAttr attribute
-                         for the operation}],
-      /*returnType=*/  "StringRef",
-      /*methodName=*/  "getIntegerOverflowAttrName",
-      /*args=*/        (ins),
-      /*methodBody=*/  [{}],
-      /*defaultImpl=*/ [{
-        return "overflowFlags";
-      }]
-      >
+    InterfaceMethod<[{
+      Get the integer overflow flags for the operation.
+    }], "IntegerOverflowFlags", "getOverflowFlags", (ins), [{}], [{
+      return $_op.getProperties().overflowFlags;
+    }]>,
+    InterfaceMethod<[{
+      Set the integer overflow flags for the operation.
+    }], "void", "setOverflowFlags", (ins "IntegerOverflowFlags":$flags), [{}], [{
+      $_op.getProperties().overflowFlags = flags;
+    }]>,
+    InterfaceMethod<[{
+      Returns whether the operation has the No Unsigned Wrap keyword.
+    }], "bool", "hasNoUnsignedWrap", (ins), [{}], [{
+      return bitEnumContainsAll($_op.getOverflowFlags(),
+                                IntegerOverflowFlags::nuw);
+    }]>,
+    InterfaceMethod<[{
+      Returns whether the operation has the No Signed Wrap keyword.
+    }], "bool", "hasNoSignedWrap", (ins), [{}], [{
+      return bitEnumContainsAll($_op.getOverflowFlags(),
+                                IntegerOverflowFlags::nsw);
+    }]>,
+    StaticInterfaceMethod<[{
+      Get the attribute name of the overflow flags property.
+    }], "StringRef", "getOverflowFlagsAttrName", (ins), [{}], [{
+      return "overflowFlags";
+    }]>,
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index f8f9264b3889be..f6dca8e2338816 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -60,16 +60,16 @@ class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName,
     LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName,
     !listconcat([DeclareOpInterfaceMethods<IntegerOverflowFlagsInterface>], traits)> {
   dag iofArg = (
-    ins DefaultValuedAttr<LLVM_IntegerOverflowFlagsAttr, "{}">:$overflowFlags);
+    ins EnumProperty<"IntegerOverflowFlags">:$overflowFlags);
   let arguments = !con(commonArgs, iofArg);
   string mlirBuilder = [{
     auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
-    moduleImport.setIntegerOverflowFlagsAttr(inst, op);
+    moduleImport.setIntegerOverflowFlags(inst, op);
     $res = op;
   }];
   let assemblyFormat = [{
-    $lhs `,` $rhs (`overflow` `` $overflowFlags^)?
-    custom<LLVMOpAttrs>(attr-dict) `:` type($res)
+    $lhs `,` $rhs `` custom<OverflowFlags>($overflowFlags)
+    `` custom<LLVMOpAttrs>(attr-dict) `:` type($res)
   }];
   string llvmBuilder =
     "$res = builder.Create" # instName #
diff --git a/mlir/include/mlir/IR/Properties.td b/mlir/include/mlir/IR/Properties.td
index 99da1763524fa9..0babdbbfa05bc2 100644
--- a/mlir/include/mlir/IR/Properties.td
+++ b/mlir/include/mlir/IR/Properties.td
@@ -153,4 +153,17 @@ class ArrayProperty<string storageTypeParam = "", int n, string desc = ""> :
   let assignToStorage = "::llvm::copy($_value, $_storage)";
 }
 
+class EnumProperty<string storageTypeParam, string desc = ""> :
+    Property<storageTypeParam, desc> {
+  code writeToMlirBytecode = [{
+    $_writer.writeVarInt(static_cast<uint64_t>($_storage));
+  }];
+  code readFromMlirBytecode = [{
+    uint64_t val;
+    if (failed($_reader.readVarInt(val)))
+      return ::mlir::failure();
+    $_storage = static_cast<}] # storageTypeParam # [{>(val);
+  }];
+}
+
 #endif // PROPERTIES
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index b551eb937cfe8d..6180d17697c271 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -183,8 +183,7 @@ class ModuleImport {
   /// Sets the integer overflow flags (nsw/nuw) attribute for the imported
   /// operation `op` given the original instruction `inst`. Asserts if the
   /// operation does not implement the integer overflow flag interface.
-  void setIntegerOverflowFlagsAttr(llvm::Instruction *inst,
-                                   Operation *op) const;
+  void setIntegerOverflowFlags(llvm::Instruction *inst, Operation *op) const;
 
   /// Sets the fastmath flags attribute for the imported operation `op` given
   /// the original instruction `inst`. Asserts if the operation does not
diff --git a/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
index f12eba98480d33..cf60a048f782c6 100644
--- a/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
+++ b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
@@ -49,13 +49,6 @@ LLVM::IntegerOverflowFlags mlir::arith::convertArithOverflowFlagsToLLVM(
   return llvmFlags;
 }
 
-LLVM::IntegerOverflowFlagsAttr mlir::arith::convertArithOverflowAttrToLLVM(
-    arith::IntegerOverflowFlagsAttr flagsAttr) {
-  arith::IntegerOverflowFlags arithFlags = flagsAttr.getValue();
-  return LLVM::IntegerOverflowFlagsAttr::get(
-      flagsAttr.getContext(), convertArithOverflowFlagsToLLVM(arithFlags));
-}
-
 LLVM::RoundingMode
 mlir::arith::convertArithRoundingModeToLLVM(arith::RoundingMode roundingMode) {
   switch (roundingMode) {
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 83c31a204efc7e..1886dfa870961a 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -329,14 +329,19 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
 // Detail methods
 //===----------------------------------------------------------------------===//
 
+void LLVM::detail::setNativeProperties(Operation *op,
+                                       IntegerOverflowFlags overflowFlags) {
+  if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op))
+    iface.setOverflowFlags(overflowFlags);
+}
+
 /// Replaces the given operation "op" with a new operation of type "targetOp"
 /// and given operands.
-LogicalResult
-LLVM::detail::oneToOneRewrite(Operation *op, StringRef targetOp,
-                              ValueRange operands,
-                              ArrayRef<NamedAttribute> targetAttrs,
-                              const LLVMTypeConverter &typeConverter,
-                              ConversionPatternRewriter &rewriter) {
+LogicalResult LLVM::detail::oneToOneRewrite(
+    Operation *op, StringRef targetOp, ValueRange operands,
+    ArrayRef<NamedAttribute> targetAttrs,
+    const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
+    IntegerOverflowFlags overflowFlags) {
   unsigned numResults = op->getNumResults();
 
   SmallVector<Type> resultTypes;
@@ -352,6 +357,8 @@ LLVM::detail::oneToOneRewrite(Operation *op, StringRef targetOp,
       rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
                       resultTypes, targetAttrs);
 
+  setNativeProperties(newOp, overflowFlags);
+
   // If the operation produced 0 or 1 result, return them immediately.
   if (numResults == 0)
     return rewriter.eraseOp(op), success();
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index 544bcc71aca1b5..626135c10a3e96 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -103,12 +103,11 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
   return success();
 }
 
-LogicalResult
-LLVM::detail::vectorOneToOneRewrite(Operation *op, StringRef targetOp,
-                                    ValueRange operands,
-                                    ArrayRef<NamedAttribute> targetAttrs,
-                                    const LLVMTypeConverter &typeConverter,
-                                    ConversionPatternRewriter &rewriter) {
+LogicalResult LLVM::detail::vectorOneToOneRewrite(
+    Operation *op, StringRef targetOp, ValueRange operands,
+    ArrayRef<NamedAttribute> targetAttrs,
+    const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
+    IntegerOverflowFlags overflowFlags) {
   assert(!operands.empty());
 
   // Cannot convert ops if their operands are not of LLVM type.
@@ -118,14 +117,15 @@ LLVM::detail::vectorOneToOneRewrite(Operation *op, StringRef targetOp,
   auto llvmNDVectorTy = operands[0].getType();
   if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy))
     return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter,
-                           rewriter);
+                           rewriter, overflowFlags);
 
-  auto callback = [op, targetOp, targetAttrs, &rewriter](Type llvm1DVectorTy,
-                                                         ValueRange operands) {
-    return rewriter
-        .create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
-                llvm1DVectorTy, targetAttrs)
-        ->getResult(0);
+  auto callback = [op, targetOp, targetAttrs, overflowFlags,
+                   &rewriter](Type llvm1DVectorTy, ValueRange operands) {
+    Operation *newOp =
+        rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp),
+                        operands, llvm1DVectorTy, targetAttrs);
+    LLVM::detail::setNativeProperties(newOp, overflowFlags);
+    return newOp->getResult(0);
   };
 
   return handleMultidimensionalVectors(op, operands, typeConverter, callback,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index f90240a67dcc5f..84994d816ad1a1 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -47,6 +47,74 @@ using mlir::LLVM::linkage::getMaxEnumValForLinkage;
 
 #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
 
+//===----------------------------------------------------------------------===//
+// Property Helpers
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// IntegerOverflowFlags
+
+namespace mlir {
+static Attribute convertToAttribute(MLIRContext *ctx,
+                                    IntegerOverflowFlags flags) {
+  return IntegerOverflowFlagsAttr::get(ctx, flags);
+}
+
+static LogicalResult
+convertFromAttribute(IntegerOverflowFlags &flags, Attribute attr,
+                     function_ref<InFlightDiagnostic()> emitError) {
+  auto flagsAttr = dyn_cast<IntegerOverflowFlagsAttr>(attr);
+  if (!flagsAttr) {
+    return emitError() << "expected 'overflowFlags' attribute to be an "
+                          "IntegerOverflowFlagsAttr, but got "
+                       << attr;
+  }
+  flags = flagsAttr.getValue();
+  return success();
+}
+} // namespace mlir
+
+static ParseResult parseOverflowFlags(AsmParser &p,
+                                      IntegerOverflowFlags &flags) {
+  if (failed(p.parseOptionalKeyword("overflow"))) {
+    flags = IntegerOverflowFlags::none;
+    return success();
+  }
+  if (p.parseLess())
+    return failure();
+  do {
+    StringRef kw;
+    SMLoc loc = p.getCurrentLocation();
+    if (p.parseKeyword(&kw))
+      return failure();
+    std::optional<IntegerOverflowFlags> flag =
+        symbolizeIntegerOverflowFlags(kw);
+    if (!flag)
+      return p.emitError(loc,
+                         "invalid overflow flag: expected nsw, nuw, or none");
+    flags = flags | *flag;
+  } while (succeeded(p.parseOptionalComma()));
+  return p.parseGreater();
+}
+
+static void printOverflowFlags(AsmPrinter &p, Operation *op,
+                               IntegerOverflowFlags flags) {
+  if (flags == IntegerOverflowFlags::none)
+    return;
+  p << " overflow<";
+  SmallVector<StringRef, 2> strs;
+  if (bitEnumContainsAny(flags, IntegerOverflowFlags::nsw))
+    strs.push_back("nsw");
+  if (bitEnumContainsAny(flags, IntegerOverflowFlags::nuw))
+    strs.push_back("nuw");
+  llvm::interleaveComma(strs, p);
+  p << ">";
+}
+
+//===----------------------------------------------------------------------===//
+// Attribute Helpers
+//===----------------------------------------------------------------------===//
+
 ...
[truncated]

Copy link
Member

@jpienaar jpienaar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice cleanup

@Mogball Mogball merged commit 0c41eea into llvm:main Apr 18, 2024
@clementval
Copy link
Contributor

clementval commented Apr 19, 2024

This is breaking most of flang buildbots. The pre-ci was showing the error.

@@ -60,16 +60,16 @@ class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName,
LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName,
!listconcat([DeclareOpInterfaceMethods<IntegerOverflowFlagsInterface>], traits)> {
dag iofArg = (
ins DefaultValuedAttr<LLVM_IntegerOverflowFlagsAttr, "{}">:$overflowFlags);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is LLVM_IntegerOverflowFlagsAttr still used after this PR or could it be removed as well?

Thanks for the change!

Mogball pushed a commit to Mogball/llvm-project that referenced this pull request Apr 19, 2024
…m#89312)

This PR changes the LLVM dialect's IntegerOverflowFlags to be stored on
operations as native properties.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants