Skip to content

Backported all emitc-related commits from upstream #170

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions llvm/include/llvm/Support/Casting.h
Original file line number Diff line number Diff line change
Expand Up @@ -801,6 +801,52 @@ template <class X, class Y>
return unique_dyn_cast_or_null<X, Y>(Val);
}

//===----------------------------------------------------------------------===//
// Isa Predicates
//===----------------------------------------------------------------------===//

/// These are wrappers over isa* function that allow them to be used in generic
/// algorithms such as `llvm:all_of`, `llvm::none_of`, etc. This is accomplished
/// by exposing the isa* functions through function objects with a generic
/// function call operator.

namespace detail {
template <typename... Types> struct IsaCheckPredicate {
template <typename T> [[nodiscard]] bool operator()(const T &Val) const {
return isa<Types...>(Val);
}
};

template <typename... Types> struct IsaAndPresentCheckPredicate {
template <typename T> [[nodiscard]] bool operator()(const T &Val) const {
return isa_and_present<Types...>(Val);
}
};
} // namespace detail

/// Function object wrapper for the `llvm::isa` type check. The function call
/// operator returns true when the value can be cast to any type in `Types`.
/// Example:
/// ```
/// SmallVector<Type> myTypes = ...;
/// if (llvm::all_of(myTypes, llvm::IsaPred<VectorType>))
/// ...
/// ```
template <typename... Types>
inline constexpr detail::IsaCheckPredicate<Types...> IsaPred{};

/// Function object wrapper for the `llvm::isa_and_present` type check. The
/// function call operator returns true when the value can be cast to any type
/// in `Types`, or if the value is not present (e.g., nullptr). Example:
/// ```
/// SmallVector<Type> myTypes = ...;
/// if (llvm::all_of(myTypes, llvm::IsaAndPresentPred<VectorType>))
/// ...
/// ```
template <typename... Types>
inline constexpr detail::IsaAndPresentCheckPredicate<Types...>
IsaAndPresentPred{};

} // end namespace llvm

#endif // LLVM_SUPPORT_CASTING_H
15 changes: 15 additions & 0 deletions llvm/unittests/Support/Casting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,21 @@ TEST(CastingTest, dyn_cast_if_present) {
EXPECT_FALSE(t4.hasValue);
}

TEST(CastingTest, isa_check_predicates) {
auto IsaFoo = IsaPred<foo>;
EXPECT_TRUE(IsaFoo(B1));
EXPECT_TRUE(IsaFoo(B2));
EXPECT_TRUE(IsaFoo(B3));
EXPECT_TRUE(IsaPred<foo>(B4));
EXPECT_TRUE((IsaPred<foo, bar>(B4)));

auto IsaAndPresentFoo = IsaAndPresentPred<foo>;
EXPECT_TRUE(IsaAndPresentFoo(B2));
EXPECT_TRUE(IsaAndPresentFoo(B4));
EXPECT_FALSE(IsaAndPresentPred<foo>(fub()));
EXPECT_FALSE((IsaAndPresentPred<foo, bar>(fub())));
}

std::unique_ptr<derived> newd() { return std::make_unique<derived>(); }
std::unique_ptr<base> newb() { return std::make_unique<derived>(); }

Expand Down
7 changes: 7 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,17 @@
namespace mlir {
namespace emitc {
void buildTerminatedBody(OpBuilder &builder, Location loc);

/// Determines whether \p type is valid in EmitC.
bool isSupportedEmitCType(mlir::Type type);

/// Determines whether \p type is a valid integer type in EmitC.
bool isSupportedIntegerType(mlir::Type type);

/// Determines whether \p type is integer like, i.e. it's a supported integer,
/// an index or opaque type.
bool isIntegerIndexOrOpaqueType(Type type);

/// Determines whether \p type is a valid floating-point type in EmitC.
bool isSupportedFloatType(mlir::Type type);
} // namespace emitc
Expand Down
60 changes: 32 additions & 28 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -347,9 +347,8 @@ def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> {
%0 = "emitc.constant"(){value = 42 : i32} : () -> i32

// Constant emitted as `char = CHAR_MIN;`
%1 = "emitc.constant"()
{value = #emitc.opaque<"CHAR_MIN"> : !emitc.opaque<"char">}
: () -> !emitc.opaque<"char">
%1 = "emitc.constant"() {value = #emitc.opaque<"CHAR_MIN">}
: () -> !emitc.opaque<"char">
```
}];

Expand Down Expand Up @@ -992,9 +991,8 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> {
%0 = "emitc.variable"(){value = 42 : i32} : () -> i32

// Variable emitted as `int32_t* = NULL;`
%1 = "emitc.variable"()
{value = #emitc.opaque<"NULL"> : !emitc.opaque<"int32_t*">}
: () -> !emitc.opaque<"int32_t*">
%1 = "emitc.variable"() {value = #emitc.opaque<"NULL">}
: () -> !emitc.ptr<!emitc.opaque<"int32_t">>
```

Since folding is not supported, it can be used with pointers.
Expand Down Expand Up @@ -1022,12 +1020,12 @@ def EmitC_GlobalOp : EmitC_Op<"global", [Symbol]> {
The `emitc.global` operation declares or defines a named global variable.
The backing memory for the variable is allocated statically and is
described by the type of the variable.
Optionally, and `initial_value` can be provided.
Internal linkage can be specified using the `staticSpecifier` unit attribute
and external linkage can be specified using the `externSpecifier` unit attribute.
Optionally, an `initial_value` can be provided.
Internal linkage can be specified using the `static_specifier` unit attribute
and external linkage can be specified using the `extern_specifier` unit attribute.
Note that the default linkage without those two keywords depends on whether
the target is C or C++ and whether the global variable is `const`.
The global variable can also be marked constant using the `constSpecifier`
The global variable can also be marked constant using the `const_specifier`
unit attribute. Writing to such constant global variables is
undefined.

Expand All @@ -1049,14 +1047,14 @@ def EmitC_GlobalOp : EmitC_Op<"global", [Symbol]> {
let arguments = (ins SymbolNameAttr:$sym_name,
TypeAttr:$type,
OptionalAttr<EmitC_OpaqueOrTypedAttr>:$initial_value,
UnitAttr:$externSpecifier,
UnitAttr:$staticSpecifier,
UnitAttr:$constSpecifier);
UnitAttr:$extern_specifier,
UnitAttr:$static_specifier,
UnitAttr:$const_specifier);

let assemblyFormat = [{
(`extern` $externSpecifier^)?
(`static` $staticSpecifier^)?
(`const` $constSpecifier^)?
(`extern` $extern_specifier^)?
(`static` $static_specifier^)?
(`const` $const_specifier^)?
$sym_name
`:` custom<EmitCGlobalOpTypeAndInitialValue>($type, $initial_value)
attr-dict
Expand Down Expand Up @@ -1224,35 +1222,41 @@ def EmitC_IfOp : EmitC_Op<"if",
let hasCustomAssemblyFormat = 1;
}

def EmitC_SubscriptOp : EmitC_Op<"subscript",
[TypesMatchWith<"result type matches element type of 'array'",
"array", "result",
"::llvm::cast<ArrayType>($_self).getElementType()">]> {
let summary = "Array subscript operation";
def EmitC_SubscriptOp : EmitC_Op<"subscript", []> {
let summary = "Subscript operation";
let description = [{
With the `subscript` operation the subscript operator `[]` can be applied
to variables or arguments of array type.
to variables or arguments of array, pointer and opaque type.

Example:

```mlir
%i = index.constant 1
%j = index.constant 7
%0 = emitc.subscript %arg0[%i, %j] : <4x8xf32>, index, index
%0 = emitc.subscript %arg0[%i, %j] : !emitc.array<4x8xf32>, index, index
%1 = emitc.subscript %arg1[%i] : !emitc.ptr<i32>, index
```
}];
let arguments = (ins Arg<EmitC_ArrayType, "the reference to load from">:$array,
Variadic<IntegerIndexOrOpaqueType>:$indices);
let arguments = (ins Arg<AnyTypeOf<[
EmitC_ArrayType,
EmitC_OpaqueType,
EmitC_PointerType]>,
"the value to subscript">:$value,
Variadic<EmitCType>:$indices);
let results = (outs EmitCType:$result);

let builders = [
OpBuilder<(ins "Value":$array, "ValueRange":$indices), [{
build($_builder, $_state, cast<ArrayType>(array.getType()).getElementType(), array, indices);
OpBuilder<(ins "TypedValue<ArrayType>":$array, "ValueRange":$indices), [{
build($_builder, $_state, array.getType().getElementType(), array, indices);
}]>,
OpBuilder<(ins "TypedValue<PointerType>":$pointer, "Value":$index), [{
build($_builder, $_state, pointer.getType().getPointee(), pointer,
ValueRange{index});
}]>
];

let hasVerifier = 1;
let assemblyFormat = "$array `[` $indices `]` attr-dict `:` type($array) `,` type($indices)";
let assemblyFormat = "$value `[` $indices `]` attr-dict `:` functional-type(operands, results)";
}


Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Target/Cpp/CppEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#ifndef MLIR_TARGET_CPP_CPPEMITTER_H
#define MLIR_TARGET_CPP_CPPEMITTER_H

#include "llvm/Support/raw_ostream.h"
#include "mlir/Support/LLVM.h"

namespace mlir {
struct LogicalResult;
Expand Down
55 changes: 52 additions & 3 deletions mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,55 @@ class ArithOpConversion final : public OpConversionPattern<ArithOp> {
}
};

template <typename ArithOp, typename EmitCOp>
class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
public:
using OpConversionPattern<ArithOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

Type type = this->getTypeConverter()->convertType(op.getType());
if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
return rewriter.notifyMatchFailure(op, "expected integer type");
}

if (type.isInteger(1)) {
// arith expects wrap-around arithmethic, which doesn't happen on `bool`.
return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
}

Value lhs = adaptor.getLhs();
Value rhs = adaptor.getRhs();
Type arithmeticType = type;
if ((type.isSignlessInteger() || type.isSignedInteger()) &&
!bitEnumContainsAll(op.getOverflowFlags(),
arith::IntegerOverflowFlags::nsw)) {
// If the C type is signed and the op doesn't guarantee "No Signed Wrap",
// we compute in unsigned integers to avoid UB.
arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
/*isSigned=*/false);
}
if (arithmeticType != type) {
lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
lhs);
rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
rhs);
}

Value result = rewriter.template create<EmitCOp>(op.getLoc(),
arithmeticType, lhs, rhs);

if (arithmeticType != type) {
result =
rewriter.template create<emitc::CastOp>(op.getLoc(), type, result);
}
rewriter.replaceOp(op, result);
return success();
}
};

class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
public:
using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
Expand Down Expand Up @@ -432,9 +481,9 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
ArithOpConversion<arith::RemSIOp, emitc::RemOp>,
ArithOpConversion<arith::MulFOp, emitc::MulOp>,
ArithOpConversion<arith::SubFOp, emitc::SubOp>,
ArithOpConversion<arith::AddIOp, emitc::AddOp>,
ArithOpConversion<arith::MulIOp, emitc::MulOp>,
ArithOpConversion<arith::SubIOp, emitc::SubOp>,
IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
CmpFOpConversion,
CmpIOpConversion,
SelectOpConversion,
Expand Down
1 change: 0 additions & 1 deletion mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ namespace {
struct ConvertArithToEmitC
: public impl::ConvertArithToEmitCBase<ConvertArithToEmitC> {
using Base::Base;

void runOnOperation() override;
};
} // namespace
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/FuncToEmitC/FuncToEmitCPass.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//===- FuncToEmitC.cpp - Func to EmitC Pass ---------------------*- C++ -*-===//
//===- FuncToEmitCPass.cpp - Func to EmitC Pass -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand Down
25 changes: 20 additions & 5 deletions mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,18 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
op.getLoc(),
"only public and private visibility is currently supported");
}
// We are explicit in specifier the linkage because the default linkage
// We are explicit in specifing the linkage because the default linkage
// for constants is different in C and C++.
bool staticSpecifier = visibility == SymbolTable::Visibility::Private;
bool externSpecifier = !staticSpecifier;

Attribute initialValue = operands.getInitialValueAttr();
if (isa_and_present<UnitAttr>(initialValue))
initialValue = {};

rewriter.replaceOpWithNewOp<emitc::GlobalOp>(
op, operands.getSymName(), resultTy, operands.getInitialValueAttr(),
externSpecifier, staticSpecifier, operands.getConstant());
op, operands.getSymName(), resultTy, initialValue, externSpecifier,
staticSpecifier, operands.getConstant());
return success();
}
};
Expand Down Expand Up @@ -124,8 +128,14 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
}

auto arrayValue =
dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
if (!arrayValue) {
return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
}

auto subscript = rewriter.create<emitc::SubscriptOp>(
op.getLoc(), operands.getMemref(), operands.getIndices());
op.getLoc(), arrayValue, operands.getIndices());

auto noInit = emitc::OpaqueAttr::get(getContext(), "");
auto var =
Expand All @@ -143,9 +153,14 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
LogicalResult
matchAndRewrite(memref::StoreOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {
auto arrayValue =
dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
if (!arrayValue) {
return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
}

auto subscript = rewriter.create<emitc::SubscriptOp>(
op.getLoc(), operands.getMemref(), operands.getIndices());
op.getLoc(), arrayValue, operands.getIndices());
rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
operands.getValue());
return success();
Expand Down
Loading