Skip to content

[mlir][llvm] Add llvm.target_features features attribute #71510

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -933,4 +933,68 @@ def LLVM_VScaleRangeAttr : LLVM_Attr<"VScaleRange", "vscale_range"> {
"IntegerAttr":$maxRange);
let assemblyFormat = "`<` struct(params) `>`";
}

//===----------------------------------------------------------------------===//
// TargetFeaturesAttr
//===----------------------------------------------------------------------===//

def LLVM_TargetFeaturesAttr : LLVM_Attr<"TargetFeatures", "target_features">
{
let summary = "LLVM target features attribute";

let description = [{
Represents the LLVM target features as a list that can be checked within
passes/rewrites.

Example:
```mlir
#llvm.target_features<["+sme", "+sve", "+sme-f64f64"]>
```

Then within a pass or rewrite the features active at an op can be queried:

```c++
auto targetFeatures = LLVM::TargetFeaturesAttr::featuresAt(op);

if (!targetFeatures.contains("+sme-f64f64"))
return failure();
```
}];

let parameters = (ins OptionalArrayRefParameter<"StringAttr">:$features);

let builders = [
TypeBuilder<(ins "::llvm::StringRef":$features)>,
TypeBuilder<(ins "::llvm::ArrayRef<::llvm::StringRef>":$features)>
];

let extraClassDeclaration = [{
/// Checks if a feature is contained within the features list.
/// Note: Using a StringAttr allows doing pointer-comparisons.
bool contains(::mlir::StringAttr feature) const;
bool contains(::llvm::StringRef feature) const;

bool nullOrEmpty() const {
// Checks if this attribute is null, or the features are empty.
return !bool(*this) || getFeatures().empty();
}

/// Returns the list of features as an LLVM-compatible string.
std::string getFeaturesString() const;

/// Finds the target features on the parent FunctionOpInterface.
/// Note: This assumes the attribute name matches the return value of
/// `getAttributeName()`.
static TargetFeaturesAttr featuresAt(Operation* op);

/// Canonical name for this attribute within MLIR.
static constexpr StringLiteral getAttributeName() {
return StringLiteral("target_features");
}
}];

let assemblyFormat = "`<` `[` (`]`) : ($features^ `]`)? `>`";
let genVerifyDecl = 1;
}

#endif // LLVMIR_ATTRDEFS
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1394,7 +1394,8 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
OptionalAttr<UnnamedAddr>:$unnamed_addr,
OptionalAttr<I64Attr>:$alignment,
OptionalAttr<LLVM_VScaleRangeAttr>:$vscale_range,
OptionalAttr<FramePointerKindAttr>:$frame_pointer
OptionalAttr<FramePointerKindAttr>:$frame_pointer,
OptionalAttr<LLVM_TargetFeaturesAttr>:$target_features
);

let regions = (region AnyRegion:$body);
Expand Down
65 changes: 65 additions & 0 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/BinaryFormat/Dwarf.h"
Expand Down Expand Up @@ -183,3 +184,67 @@ void printExpressionArg(AsmPrinter &printer, uint64_t opcode,
i++;
});
}

//===----------------------------------------------------------------------===//
// TargetFeaturesAttr
//===----------------------------------------------------------------------===//

TargetFeaturesAttr TargetFeaturesAttr::get(MLIRContext *context,
llvm::ArrayRef<StringRef> features) {
return Base::get(context,
llvm::map_to_vector(features, [&](StringRef feature) {
return StringAttr::get(context, feature);
}));
}

TargetFeaturesAttr TargetFeaturesAttr::get(MLIRContext *context,
StringRef targetFeatures) {
SmallVector<StringRef> features;
targetFeatures.split(features, ',', /*MaxSplit=*/-1,
/*KeepEmpty=*/false);
return get(context, features);
}

LogicalResult
TargetFeaturesAttr::verify(function_ref<InFlightDiagnostic()> emitError,
llvm::ArrayRef<StringAttr> features) {
for (StringAttr featureAttr : features) {
if (!featureAttr || featureAttr.empty())
return emitError() << "target features can not be null or empty";
auto feature = featureAttr.strref();
if (feature[0] != '+' && feature[0] != '-')
return emitError() << "target features must start with '+' or '-'";
if (feature.contains(','))
return emitError() << "target features can not contain ','";
}
return success();
}

bool TargetFeaturesAttr::contains(StringAttr feature) const {
if (nullOrEmpty())
return false;
// Note: Using StringAttr does pointer comparisons.
return llvm::is_contained(getFeatures(), feature);
}

bool TargetFeaturesAttr::contains(StringRef feature) const {
if (nullOrEmpty())
return false;
return llvm::is_contained(getFeatures(), feature);
}

std::string TargetFeaturesAttr::getFeaturesString() const {
std::string featuresString;
llvm::raw_string_ostream ss(featuresString);
llvm::interleave(
getFeatures(), ss, [&](auto &feature) { ss << feature.strref(); }, ",");
return ss.str();
}

TargetFeaturesAttr TargetFeaturesAttr::featuresAt(Operation *op) {
auto parentFunction = op->getParentOfType<FunctionOpInterface>();
if (!parentFunction)
return {};
return parentFunction.getOperation()->getAttrOfType<TargetFeaturesAttr>(
getAttributeName());
}
7 changes: 7 additions & 0 deletions mlir/lib/Target/LLVMIR/ModuleImport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1627,6 +1627,7 @@ static constexpr std::array ExplicitAttributes{
StringLiteral("aarch64_pstate_za_new"),
StringLiteral("vscale_range"),
StringLiteral("frame-pointer"),
StringLiteral("target-features"),
};

static void processPassthroughAttrs(llvm::Function *func, LLVMFuncOp funcOp) {
Expand Down Expand Up @@ -1717,6 +1718,12 @@ void ModuleImport::processFunctionAttributes(llvm::Function *func,
stringRefFramePointerKind)
.value()));
}

if (llvm::Attribute attr = func->getFnAttribute("target-features");
attr.isStringAttribute()) {
funcOp.setTargetFeaturesAttr(
LLVM::TargetFeaturesAttr::get(context, attr.getValueAsString()));
Comment on lines +1722 to +1725
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: personally would find it clearer if attr.isStringAttribute() was in a separate if, threw me off a little.

Suggested change
if (llvm::Attribute attr = func->getFnAttribute("target-features");
attr.isStringAttribute()) {
funcOp.setTargetFeaturesAttr(
LLVM::TargetFeaturesAttr::get(context, attr.getValueAsString()));
if (llvm::Attribute attr = func->getFnAttribute("target-features")) {
if (attr.isStringAttribute())
funcOp.setTargetFeaturesAttr(
LLVM::TargetFeaturesAttr::get(context, attr.getValueAsString()));
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's LLVM style to avoid extra nesting, so I'm going keep this :)

}
}

DictionaryAttr
Expand Down
3 changes: 3 additions & 0 deletions mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,9 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
if (func.getArmNewZa())
llvmFunc->addFnAttr("aarch64_pstate_za_new");

if (auto targetFeatures = func.getTargetFeatures())
llvmFunc->addFnAttr("target-features", targetFeatures->getFeaturesString());

if (auto attr = func.getVscaleRange())
llvmFunc->addFnAttr(llvm::Attribute::getWithVScaleRangeArgs(
getLLVMContext(), attr->getMinRange().getInt(),
Expand Down
9 changes: 9 additions & 0 deletions mlir/test/Target/LLVMIR/Import/target-features.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s

; CHECK-LABEL: llvm.func @target_features()
; CHECK-SAME: #llvm.target_features<["+sme", "+sme-f64f64", "+sve"]>
define void @target_features() #0 {
ret void
}

attributes #0 = { "target-features"="+sme,+sme-f64f64,+sve" }
21 changes: 21 additions & 0 deletions mlir/test/Target/LLVMIR/llvmir-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,27 @@ llvm.func @stepvector_intr_wrong_type() -> vector<7xf32> {

// -----

// expected-error @below{{target features can not contain ','}}
llvm.func @invalid_target_feature() attributes { target_features = #llvm.target_features<["+bad,feature", "+test"]> }
{
}

// -----

// expected-error @below{{target features must start with '+' or '-'}}
llvm.func @missing_target_feature_prefix() attributes { target_features = #llvm.target_features<["sme"]> }
{
}

// -----

// expected-error @below{{target features can not be null or empty}}
llvm.func @empty_target_feature() attributes { target_features = #llvm.target_features<["", "+sve"]> }
{
}

// -----

llvm.comdat @__llvm_comdat {
llvm.comdat_selector @foo any
}
Expand Down
9 changes: 9 additions & 0 deletions mlir/test/Target/LLVMIR/target-features.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s

// CHECK-LABEL: define void @target_features
// CHECK: attributes #{{.*}} = { "target-features"="+sme,+sve,+sme-f64f64" }
llvm.func @target_features() attributes {
target_features = #llvm.target_features<["+sme", "+sve", "+sme-f64f64"]>
} {
llvm.return
}