-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[CIR] Realign CIR-to-LLVM IR lowering code with incubator #129293
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
We previously discussed having an mlir-tblgen utility to complete the CIRAttrVisitor implementation with all support attribute types, but when I proposed an implementation to do this, a reviewer suggested using TypeSwitch instead, and I have done that in the incubator. See llvm#126332 This change brings the TypeSwitch implementation into the upstream repo to replace the visitor class.
@llvm/pr-subscribers-clang Author: Andy Kaylor (andykaylor) ChangesThe previously upstreamed lowering from ClangIR to LLVM IR diverged from the incubator implementation, but when the incubator was updated to incorporate these changes some issues arose which require the upstream implementation to be modified to re-align with the incubator. First, in the earlier upstream implementation a CIRAttrVisitor class was introduced with the intention that an mlir-tblgen based extension would be created to automatically add all CIR attributes to the visitor. When I proposed this in mlir-tblgen a reviewer suggested that what I wanted could be better accomplished with TypeSwitch. See #126332 This was done in the incubator, and here I am bringing that implementation upstream. The other issue was that the global op initialization in the incubator had more cases than I had accounted for in my previous upstream refactoring. I did still refactor the incubator code, but not in quite the same way as the upstream code. This change re-aligns the two. Full diff: https://github.com/llvm/llvm-project/pull/129293.diff 3 Files Affected:
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRAttrVisitor.h b/clang/include/clang/CIR/Dialect/IR/CIRAttrVisitor.h
deleted file mode 100644
index bbba89cb7e3fd..0000000000000
--- a/clang/include/clang/CIR/Dialect/IR/CIRAttrVisitor.h
+++ /dev/null
@@ -1,52 +0,0 @@
-//===- CIRAttrVisitor.h - Visitor for CIR attributes ------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines the CirAttrVisitor interface.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef LLVM_CLANG_CIR_DIALECT_IR_CIRATTRVISITOR_H
-#define LLVM_CLANG_CIR_DIALECT_IR_CIRATTRVISITOR_H
-
-#include "clang/CIR/Dialect/IR/CIRAttrs.h"
-
-namespace cir {
-
-template <typename ImplClass, typename RetTy> class CirAttrVisitor {
-public:
- // FIXME: Create a TableGen list to automatically handle new attributes
- RetTy visit(mlir::Attribute attr) {
- if (const auto intAttr = mlir::dyn_cast<cir::IntAttr>(attr))
- return getImpl().visitCirIntAttr(intAttr);
- if (const auto fltAttr = mlir::dyn_cast<cir::FPAttr>(attr))
- return getImpl().visitCirFPAttr(fltAttr);
- if (const auto ptrAttr = mlir::dyn_cast<cir::ConstPtrAttr>(attr))
- return getImpl().visitCirConstPtrAttr(ptrAttr);
- llvm_unreachable("unhandled attribute type");
- }
-
- // If the implementation chooses not to implement a certain visit
- // method, fall back to the parent.
- RetTy visitCirIntAttr(cir::IntAttr attr) {
- return getImpl().visitCirAttr(attr);
- }
- RetTy visitCirFPAttr(cir::FPAttr attr) {
- return getImpl().visitCirAttr(attr);
- }
- RetTy visitCirConstPtrAttr(cir::ConstPtrAttr attr) {
- return getImpl().visitCirAttr(attr);
- }
-
- RetTy visitCirAttr(mlir::Attribute attr) { return RetTy(); }
-
- ImplClass &getImpl() { return *static_cast<ImplClass *>(this); }
-};
-
-} // namespace cir
-
-#endif // LLVM_CLANG_CIR_DIALECT_IR_CIRATTRVISITOR_H
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index ba7fab2865116..5d083efcdda6f 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -24,10 +24,10 @@
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "clang/CIR/Dialect/IR/CIRAttrVisitor.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/MissingFeatures.h"
#include "clang/CIR/Passes.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/TimeProfiler.h"
@@ -37,41 +37,23 @@ using namespace llvm;
namespace cir {
namespace direct {
-class CIRAttrToValue : public CirAttrVisitor<CIRAttrToValue, mlir::Value> {
+class CIRAttrToValue {
public:
CIRAttrToValue(mlir::Operation *parentOp,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter)
: parentOp(parentOp), rewriter(rewriter), converter(converter) {}
- mlir::Value lowerCirAttrAsValue(mlir::Attribute attr) { return visit(attr); }
-
- mlir::Value visitCirIntAttr(cir::IntAttr intAttr) {
- mlir::Location loc = parentOp->getLoc();
- return rewriter.create<mlir::LLVM::ConstantOp>(
- loc, converter->convertType(intAttr.getType()), intAttr.getValue());
- }
-
- mlir::Value visitCirFPAttr(cir::FPAttr fltAttr) {
- mlir::Location loc = parentOp->getLoc();
- return rewriter.create<mlir::LLVM::ConstantOp>(
- loc, converter->convertType(fltAttr.getType()), fltAttr.getValue());
+ mlir::Value visit(mlir::Attribute attr) {
+ return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr)
+ .Case<cir::IntAttr, cir::FPAttr, cir::ConstPtrAttr>(
+ [&](auto attrT) { return visitCirAttr(attrT); })
+ .Default([&](auto attrT) { return mlir::Value(); });
}
- mlir::Value visitCirConstPtrAttr(cir::ConstPtrAttr ptrAttr) {
- mlir::Location loc = parentOp->getLoc();
- if (ptrAttr.isNullValue()) {
- return rewriter.create<mlir::LLVM::ZeroOp>(
- loc, converter->convertType(ptrAttr.getType()));
- }
- mlir::DataLayout layout(parentOp->getParentOfType<mlir::ModuleOp>());
- mlir::Value ptrVal = rewriter.create<mlir::LLVM::ConstantOp>(
- loc,
- rewriter.getIntegerType(layout.getTypeSizeInBits(ptrAttr.getType())),
- ptrAttr.getValue().getInt());
- return rewriter.create<mlir::LLVM::IntToPtrOp>(
- loc, converter->convertType(ptrAttr.getType()), ptrVal);
- }
+ mlir::Value visitCirAttr(cir::IntAttr intAttr);
+ mlir::Value visitCirAttr(cir::FPAttr fltAttr);
+ mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr);
private:
mlir::Operation *parentOp;
@@ -79,21 +61,54 @@ class CIRAttrToValue : public CirAttrVisitor<CIRAttrToValue, mlir::Value> {
const mlir::TypeConverter *converter;
};
+/// IntAttr visitor.
+mlir::Value CIRAttrToValue::visitCirAttr(cir::IntAttr intAttr) {
+ mlir::Location loc = parentOp->getLoc();
+ return rewriter.create<mlir::LLVM::ConstantOp>(
+ loc, converter->convertType(intAttr.getType()), intAttr.getValue());
+}
+
+/// ConstPtrAttr visitor.
+mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstPtrAttr ptrAttr) {
+ mlir::Location loc = parentOp->getLoc();
+ if (ptrAttr.isNullValue()) {
+ return rewriter.create<mlir::LLVM::ZeroOp>(
+ loc, converter->convertType(ptrAttr.getType()));
+ }
+ mlir::DataLayout layout(parentOp->getParentOfType<mlir::ModuleOp>());
+ mlir::Value ptrVal = rewriter.create<mlir::LLVM::ConstantOp>(
+ loc, rewriter.getIntegerType(layout.getTypeSizeInBits(ptrAttr.getType())),
+ ptrAttr.getValue().getInt());
+ return rewriter.create<mlir::LLVM::IntToPtrOp>(
+ loc, converter->convertType(ptrAttr.getType()), ptrVal);
+}
+
+/// FPAttr visitor.
+mlir::Value CIRAttrToValue::visitCirAttr(cir::FPAttr fltAttr) {
+ mlir::Location loc = parentOp->getLoc();
+ return rewriter.create<mlir::LLVM::ConstantOp>(
+ loc, converter->convertType(fltAttr.getType()), fltAttr.getValue());
+}
+
// This class handles rewriting initializer attributes for types that do not
// require region initialization.
-class GlobalInitAttrRewriter
- : public CirAttrVisitor<GlobalInitAttrRewriter, mlir::Attribute> {
+class GlobalInitAttrRewriter {
public:
GlobalInitAttrRewriter(mlir::Type type,
mlir::ConversionPatternRewriter &rewriter)
: llvmType(type), rewriter(rewriter) {}
- mlir::Attribute rewriteInitAttr(mlir::Attribute attr) { return visit(attr); }
+ mlir::Attribute visit(mlir::Attribute attr) {
+ return llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(attr)
+ .Case<cir::IntAttr, cir::FPAttr>(
+ [&](auto attrT) { return visitCirAttr(attrT); })
+ .Default([&](auto attrT) { return mlir::Attribute(); });
+ }
- mlir::Attribute visitCirIntAttr(cir::IntAttr attr) {
+ mlir::Attribute visitCirAttr(cir::IntAttr attr) {
return rewriter.getIntegerAttr(llvmType, attr.getValue());
}
- mlir::Attribute visitCirFPAttr(cir::FPAttr attr) {
+ mlir::Attribute visitCirAttr(cir::FPAttr attr) {
return rewriter.getFloatAttr(llvmType, attr.getValue());
}
@@ -124,12 +139,6 @@ struct ConvertCIRToLLVMPass
StringRef getArgument() const override { return "cir-flat-to-llvm"; }
};
-bool CIRToLLVMGlobalOpLowering::attrRequiresRegionInitialization(
- mlir::Attribute attr) const {
- // There will be more cases added later.
- return isa<cir::ConstPtrAttr>(attr);
-}
-
/// Replace CIR global with a region initialized LLVM global and update
/// insertion point to the end of the initializer block.
void CIRToLLVMGlobalOpLowering::setupRegionInitializedLLVMGlobalOp(
@@ -176,8 +185,8 @@ CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal(
// to the appropriate value.
const mlir::Location loc = op.getLoc();
setupRegionInitializedLLVMGlobalOp(op, rewriter);
- CIRAttrToValue attrVisitor(op, rewriter, typeConverter);
- mlir::Value value = attrVisitor.lowerCirAttrAsValue(init);
+ CIRAttrToValue valueConverter(op, rewriter, typeConverter);
+ mlir::Value value = valueConverter.visit(init);
rewriter.create<mlir::LLVM::ReturnOp>(loc, value);
return mlir::success();
}
@@ -188,12 +197,6 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
std::optional<mlir::Attribute> init = op.getInitialValue();
- // If we have an initializer and it requires region initialization, handle
- // that separately
- if (init.has_value() && attrRequiresRegionInitialization(init.value())) {
- return matchAndRewriteRegionInitializedGlobal(op, init.value(), rewriter);
- }
-
// Fetch required values to create LLVM op.
const mlir::Type cirSymType = op.getSymType();
@@ -218,12 +221,31 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
SmallVector<mlir::NamedAttribute> attributes;
if (init.has_value()) {
- GlobalInitAttrRewriter initRewriter(llvmType, rewriter);
- init = initRewriter.rewriteInitAttr(init.value());
- // If initRewriter returned a null attribute, init will have a value but
- // the value will be null. If that happens, initRewriter didn't handle the
- // attribute type. It probably needs to be added to GlobalInitAttrRewriter.
- if (!init.value()) {
+ if (mlir::isa<cir::FPAttr, cir::IntAttr>(init.value())) {
+ // If a directly equivalent attribute is available, use it.
+ init =
+ llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(init.value())
+ .Case<cir::FPAttr>([&](cir::FPAttr attr) {
+ return rewriter.getFloatAttr(llvmType, attr.getValue());
+ })
+ .Case<cir::IntAttr>([&](cir::IntAttr attr) {
+ return rewriter.getIntegerAttr(llvmType, attr.getValue());
+ })
+ .Default([&](mlir::Attribute attr) { return mlir::Attribute(); });
+ // If initRewriter returned a null attribute, init will have a value but
+ // the value will be null.
+ if (!init.value()) {
+ op.emitError() << "unsupported initializer '" << init.value() << "'";
+ return mlir::failure();
+ }
+ } else if (mlir::isa<cir::ConstPtrAttr>(init.value())) {
+ // TODO(cir): once LLVM's dialect has proper equivalent attributes this
+ // should be updated. For now, we use a custom op to initialize globals
+ // to the appropriate value.
+ return matchAndRewriteRegionInitializedGlobal(op, init.value(), rewriter);
+ } else {
+ // We will only get here if new initializer types are added and this
+ // code is not updated to handle them.
op.emitError() << "unsupported initializer '" << init.value() << "'";
return mlir::failure();
}
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
index b3366c1fb9337..d1109bb7e1c08 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
@@ -36,8 +36,6 @@ class CIRToLLVMGlobalOpLowering
mlir::ConversionPatternRewriter &rewriter) const override;
private:
- bool attrRequiresRegionInitialization(mlir::Attribute attr) const;
-
mlir::LogicalResult matchAndRewriteRegionInitializedGlobal(
cir::GlobalOp op, mlir::Attribute init,
mlir::ConversionPatternRewriter &rewriter) const;
|
@llvm/pr-subscribers-clangir Author: Andy Kaylor (andykaylor) ChangesThe previously upstreamed lowering from ClangIR to LLVM IR diverged from the incubator implementation, but when the incubator was updated to incorporate these changes some issues arose which require the upstream implementation to be modified to re-align with the incubator. First, in the earlier upstream implementation a CIRAttrVisitor class was introduced with the intention that an mlir-tblgen based extension would be created to automatically add all CIR attributes to the visitor. When I proposed this in mlir-tblgen a reviewer suggested that what I wanted could be better accomplished with TypeSwitch. See #126332 This was done in the incubator, and here I am bringing that implementation upstream. The other issue was that the global op initialization in the incubator had more cases than I had accounted for in my previous upstream refactoring. I did still refactor the incubator code, but not in quite the same way as the upstream code. This change re-aligns the two. Full diff: https://github.com/llvm/llvm-project/pull/129293.diff 3 Files Affected:
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRAttrVisitor.h b/clang/include/clang/CIR/Dialect/IR/CIRAttrVisitor.h
deleted file mode 100644
index bbba89cb7e3fd..0000000000000
--- a/clang/include/clang/CIR/Dialect/IR/CIRAttrVisitor.h
+++ /dev/null
@@ -1,52 +0,0 @@
-//===- CIRAttrVisitor.h - Visitor for CIR attributes ------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines the CirAttrVisitor interface.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef LLVM_CLANG_CIR_DIALECT_IR_CIRATTRVISITOR_H
-#define LLVM_CLANG_CIR_DIALECT_IR_CIRATTRVISITOR_H
-
-#include "clang/CIR/Dialect/IR/CIRAttrs.h"
-
-namespace cir {
-
-template <typename ImplClass, typename RetTy> class CirAttrVisitor {
-public:
- // FIXME: Create a TableGen list to automatically handle new attributes
- RetTy visit(mlir::Attribute attr) {
- if (const auto intAttr = mlir::dyn_cast<cir::IntAttr>(attr))
- return getImpl().visitCirIntAttr(intAttr);
- if (const auto fltAttr = mlir::dyn_cast<cir::FPAttr>(attr))
- return getImpl().visitCirFPAttr(fltAttr);
- if (const auto ptrAttr = mlir::dyn_cast<cir::ConstPtrAttr>(attr))
- return getImpl().visitCirConstPtrAttr(ptrAttr);
- llvm_unreachable("unhandled attribute type");
- }
-
- // If the implementation chooses not to implement a certain visit
- // method, fall back to the parent.
- RetTy visitCirIntAttr(cir::IntAttr attr) {
- return getImpl().visitCirAttr(attr);
- }
- RetTy visitCirFPAttr(cir::FPAttr attr) {
- return getImpl().visitCirAttr(attr);
- }
- RetTy visitCirConstPtrAttr(cir::ConstPtrAttr attr) {
- return getImpl().visitCirAttr(attr);
- }
-
- RetTy visitCirAttr(mlir::Attribute attr) { return RetTy(); }
-
- ImplClass &getImpl() { return *static_cast<ImplClass *>(this); }
-};
-
-} // namespace cir
-
-#endif // LLVM_CLANG_CIR_DIALECT_IR_CIRATTRVISITOR_H
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index ba7fab2865116..5d083efcdda6f 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -24,10 +24,10 @@
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "clang/CIR/Dialect/IR/CIRAttrVisitor.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/MissingFeatures.h"
#include "clang/CIR/Passes.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/TimeProfiler.h"
@@ -37,41 +37,23 @@ using namespace llvm;
namespace cir {
namespace direct {
-class CIRAttrToValue : public CirAttrVisitor<CIRAttrToValue, mlir::Value> {
+class CIRAttrToValue {
public:
CIRAttrToValue(mlir::Operation *parentOp,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter)
: parentOp(parentOp), rewriter(rewriter), converter(converter) {}
- mlir::Value lowerCirAttrAsValue(mlir::Attribute attr) { return visit(attr); }
-
- mlir::Value visitCirIntAttr(cir::IntAttr intAttr) {
- mlir::Location loc = parentOp->getLoc();
- return rewriter.create<mlir::LLVM::ConstantOp>(
- loc, converter->convertType(intAttr.getType()), intAttr.getValue());
- }
-
- mlir::Value visitCirFPAttr(cir::FPAttr fltAttr) {
- mlir::Location loc = parentOp->getLoc();
- return rewriter.create<mlir::LLVM::ConstantOp>(
- loc, converter->convertType(fltAttr.getType()), fltAttr.getValue());
+ mlir::Value visit(mlir::Attribute attr) {
+ return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr)
+ .Case<cir::IntAttr, cir::FPAttr, cir::ConstPtrAttr>(
+ [&](auto attrT) { return visitCirAttr(attrT); })
+ .Default([&](auto attrT) { return mlir::Value(); });
}
- mlir::Value visitCirConstPtrAttr(cir::ConstPtrAttr ptrAttr) {
- mlir::Location loc = parentOp->getLoc();
- if (ptrAttr.isNullValue()) {
- return rewriter.create<mlir::LLVM::ZeroOp>(
- loc, converter->convertType(ptrAttr.getType()));
- }
- mlir::DataLayout layout(parentOp->getParentOfType<mlir::ModuleOp>());
- mlir::Value ptrVal = rewriter.create<mlir::LLVM::ConstantOp>(
- loc,
- rewriter.getIntegerType(layout.getTypeSizeInBits(ptrAttr.getType())),
- ptrAttr.getValue().getInt());
- return rewriter.create<mlir::LLVM::IntToPtrOp>(
- loc, converter->convertType(ptrAttr.getType()), ptrVal);
- }
+ mlir::Value visitCirAttr(cir::IntAttr intAttr);
+ mlir::Value visitCirAttr(cir::FPAttr fltAttr);
+ mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr);
private:
mlir::Operation *parentOp;
@@ -79,21 +61,54 @@ class CIRAttrToValue : public CirAttrVisitor<CIRAttrToValue, mlir::Value> {
const mlir::TypeConverter *converter;
};
+/// IntAttr visitor.
+mlir::Value CIRAttrToValue::visitCirAttr(cir::IntAttr intAttr) {
+ mlir::Location loc = parentOp->getLoc();
+ return rewriter.create<mlir::LLVM::ConstantOp>(
+ loc, converter->convertType(intAttr.getType()), intAttr.getValue());
+}
+
+/// ConstPtrAttr visitor.
+mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstPtrAttr ptrAttr) {
+ mlir::Location loc = parentOp->getLoc();
+ if (ptrAttr.isNullValue()) {
+ return rewriter.create<mlir::LLVM::ZeroOp>(
+ loc, converter->convertType(ptrAttr.getType()));
+ }
+ mlir::DataLayout layout(parentOp->getParentOfType<mlir::ModuleOp>());
+ mlir::Value ptrVal = rewriter.create<mlir::LLVM::ConstantOp>(
+ loc, rewriter.getIntegerType(layout.getTypeSizeInBits(ptrAttr.getType())),
+ ptrAttr.getValue().getInt());
+ return rewriter.create<mlir::LLVM::IntToPtrOp>(
+ loc, converter->convertType(ptrAttr.getType()), ptrVal);
+}
+
+/// FPAttr visitor.
+mlir::Value CIRAttrToValue::visitCirAttr(cir::FPAttr fltAttr) {
+ mlir::Location loc = parentOp->getLoc();
+ return rewriter.create<mlir::LLVM::ConstantOp>(
+ loc, converter->convertType(fltAttr.getType()), fltAttr.getValue());
+}
+
// This class handles rewriting initializer attributes for types that do not
// require region initialization.
-class GlobalInitAttrRewriter
- : public CirAttrVisitor<GlobalInitAttrRewriter, mlir::Attribute> {
+class GlobalInitAttrRewriter {
public:
GlobalInitAttrRewriter(mlir::Type type,
mlir::ConversionPatternRewriter &rewriter)
: llvmType(type), rewriter(rewriter) {}
- mlir::Attribute rewriteInitAttr(mlir::Attribute attr) { return visit(attr); }
+ mlir::Attribute visit(mlir::Attribute attr) {
+ return llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(attr)
+ .Case<cir::IntAttr, cir::FPAttr>(
+ [&](auto attrT) { return visitCirAttr(attrT); })
+ .Default([&](auto attrT) { return mlir::Attribute(); });
+ }
- mlir::Attribute visitCirIntAttr(cir::IntAttr attr) {
+ mlir::Attribute visitCirAttr(cir::IntAttr attr) {
return rewriter.getIntegerAttr(llvmType, attr.getValue());
}
- mlir::Attribute visitCirFPAttr(cir::FPAttr attr) {
+ mlir::Attribute visitCirAttr(cir::FPAttr attr) {
return rewriter.getFloatAttr(llvmType, attr.getValue());
}
@@ -124,12 +139,6 @@ struct ConvertCIRToLLVMPass
StringRef getArgument() const override { return "cir-flat-to-llvm"; }
};
-bool CIRToLLVMGlobalOpLowering::attrRequiresRegionInitialization(
- mlir::Attribute attr) const {
- // There will be more cases added later.
- return isa<cir::ConstPtrAttr>(attr);
-}
-
/// Replace CIR global with a region initialized LLVM global and update
/// insertion point to the end of the initializer block.
void CIRToLLVMGlobalOpLowering::setupRegionInitializedLLVMGlobalOp(
@@ -176,8 +185,8 @@ CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal(
// to the appropriate value.
const mlir::Location loc = op.getLoc();
setupRegionInitializedLLVMGlobalOp(op, rewriter);
- CIRAttrToValue attrVisitor(op, rewriter, typeConverter);
- mlir::Value value = attrVisitor.lowerCirAttrAsValue(init);
+ CIRAttrToValue valueConverter(op, rewriter, typeConverter);
+ mlir::Value value = valueConverter.visit(init);
rewriter.create<mlir::LLVM::ReturnOp>(loc, value);
return mlir::success();
}
@@ -188,12 +197,6 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
std::optional<mlir::Attribute> init = op.getInitialValue();
- // If we have an initializer and it requires region initialization, handle
- // that separately
- if (init.has_value() && attrRequiresRegionInitialization(init.value())) {
- return matchAndRewriteRegionInitializedGlobal(op, init.value(), rewriter);
- }
-
// Fetch required values to create LLVM op.
const mlir::Type cirSymType = op.getSymType();
@@ -218,12 +221,31 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
SmallVector<mlir::NamedAttribute> attributes;
if (init.has_value()) {
- GlobalInitAttrRewriter initRewriter(llvmType, rewriter);
- init = initRewriter.rewriteInitAttr(init.value());
- // If initRewriter returned a null attribute, init will have a value but
- // the value will be null. If that happens, initRewriter didn't handle the
- // attribute type. It probably needs to be added to GlobalInitAttrRewriter.
- if (!init.value()) {
+ if (mlir::isa<cir::FPAttr, cir::IntAttr>(init.value())) {
+ // If a directly equivalent attribute is available, use it.
+ init =
+ llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(init.value())
+ .Case<cir::FPAttr>([&](cir::FPAttr attr) {
+ return rewriter.getFloatAttr(llvmType, attr.getValue());
+ })
+ .Case<cir::IntAttr>([&](cir::IntAttr attr) {
+ return rewriter.getIntegerAttr(llvmType, attr.getValue());
+ })
+ .Default([&](mlir::Attribute attr) { return mlir::Attribute(); });
+ // If initRewriter returned a null attribute, init will have a value but
+ // the value will be null.
+ if (!init.value()) {
+ op.emitError() << "unsupported initializer '" << init.value() << "'";
+ return mlir::failure();
+ }
+ } else if (mlir::isa<cir::ConstPtrAttr>(init.value())) {
+ // TODO(cir): once LLVM's dialect has proper equivalent attributes this
+ // should be updated. For now, we use a custom op to initialize globals
+ // to the appropriate value.
+ return matchAndRewriteRegionInitializedGlobal(op, init.value(), rewriter);
+ } else {
+ // We will only get here if new initializer types are added and this
+ // code is not updated to handle them.
op.emitError() << "unsupported initializer '" << init.value() << "'";
return mlir::failure();
}
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
index b3366c1fb9337..d1109bb7e1c08 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
@@ -36,8 +36,6 @@ class CIRToLLVMGlobalOpLowering
mlir::ConversionPatternRewriter &rewriter) const override;
private:
- bool attrRequiresRegionInitialization(mlir::Attribute attr) const;
-
mlir::LogicalResult matchAndRewriteRegionInitializedGlobal(
cir::GlobalOp op, mlir::Attribute init,
mlir::ConversionPatternRewriter &rewriter) const;
|
loc, converter->convertType(fltAttr.getType()), fltAttr.getValue()); | ||
mlir::Value visit(mlir::Attribute attr) { | ||
return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr) | ||
.Case<cir::IntAttr, cir::FPAttr, cir::ConstPtrAttr>( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Its a shame this is such a manual list/sub-par interface here :/ it would be way neater if it was able to deduce these. sigh
That said, I think this is quite a bit of a nicer interface, so I'm ok with it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it needs to be a manual list because it's never going to be the complete list. These are just the types we intend to handle here, with other types going to the Default() handler.
// the value will be null. If that happens, initRewriter didn't handle the | ||
// attribute type. It probably needs to be added to GlobalInitAttrRewriter. | ||
if (!init.value()) { | ||
if (mlir::isa<cir::FPAttr, cir::IntAttr>(init.value())) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This section is a bit worse unfortunately. We might consider extracting this type at one point.
Why did we remove the use of GlobalInitAttrRewriter
(or am I missing a use elsewhere?), and do we intend to bring it back? Could it be used here instead like it was before?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The GlobalInitAttrRewriter
would have used a TypeSwitch
for its internal implementation, and since each type being handled/visited is just a single line I thought it was just as clean to put it inline here. I can move it back to using a separate class if you like.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am just asking for 1 or the other. IF we don't need GlobalInitAttrRewriter
(It looks like we don't as you've removed its usage?) we should remove it, not update it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In fact, I see that I did both -- I modified GlobalInitAttrRewriter
to use TypeSwitch
but I rewrote this code to not use it. I guess that was an incomplete refactoring, but the good new is that it will make it easy to update it as you've suggested.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, cool :D I was beginning to think I was missing something, like another use of it! I don't mind which you do (remove it and leave inline, or switch to the type). There is a 'line' I suspect where 'leave inline' becomes too big/annoying and we want to move it, but we are far from it at the moment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The incubator only has three types that get handled here, but it does look a bit cleaner with these being handled separately. There is a possibility that the if
above could get out of sync with the types handled in GlobalInitAttrRewriter
but that would trigger the error below, so we'd catch it in development.
The previously upstreamed lowering from ClangIR to LLVM IR diverged from the incubator implementation, but when the incubator was updated to incorporate these changes some issues arose which require the upstream implementation to be modified to re-align with the incubator. First, in the earlier upstream implementation a CIRAttrVisitor class was introduced with the intention that an mlir-tblgen based extension would be created to automatically add all CIR attributes to the visitor. When I proposed this in mlir-tblgen a reviewer suggested that what I wanted could be better accomplished with TypeSwitch. See llvm#126332 This was done in the incubator, and here I am bringing that implementation upstream. The other issue was that the global op initialization in the incubator had more cases than I had accounted for in my previous upstream refactoring. I did still refactor the incubator code, but not in quite the same way as the upstream code. This change re-aligns the two.
The previously upstreamed lowering from ClangIR to LLVM IR diverged from the incubator implementation, but when the incubator was updated to incorporate these changes some issues arose which require the upstream implementation to be modified to re-align with the incubator.
First, in the earlier upstream implementation a CIRAttrVisitor class was introduced with the intention that an mlir-tblgen based extension would be created to automatically add all CIR attributes to the visitor. When I proposed this in mlir-tblgen a reviewer suggested that what I wanted could be better accomplished with TypeSwitch.
See #126332
This was done in the incubator, and here I am bringing that implementation upstream.
The other issue was that the global op initialization in the incubator had more cases than I had accounted for in my previous upstream refactoring. I did still refactor the incubator code, but not in quite the same way as the upstream code. This change re-aligns the two.