Skip to content

[mlir][EmitC] Introduce a CExpression trait #84177

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define MLIR_DIALECT_EMITC_IR_EMITC_H

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/EmitC/IR/EmitCTraits.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
Expand Down
39 changes: 18 additions & 21 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,14 @@ class EmitC_BinaryOp<string mnemonic, list<Trait> traits = []> :
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}

// EmitC OpTrait
def CExpression : NativeOpTrait<"emitc::CExpression">;

// Types only used in binary arithmetic operations.
def IntegerIndexOrOpaqueType : AnyTypeOf<[AnyInteger, Index, EmitC_OpaqueType]>;
def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[AnyFloat, IntegerIndexOrOpaqueType]>;

def EmitC_AddOp : EmitC_BinaryOp<"add", []> {
def EmitC_AddOp : EmitC_BinaryOp<"add", [CExpression]> {
let summary = "Addition operation";
let description = [{
With the `add` operation the arithmetic operator + (addition) can
Expand All @@ -74,7 +77,7 @@ def EmitC_AddOp : EmitC_BinaryOp<"add", []> {
let hasVerifier = 1;
}

def EmitC_ApplyOp : EmitC_Op<"apply", []> {
def EmitC_ApplyOp : EmitC_Op<"apply", [CExpression]> {
let summary = "Apply operation";
let description = [{
With the `apply` operation the operators & (address of) and * (contents of)
Expand Down Expand Up @@ -211,7 +214,7 @@ def EmitC_BitwiseXorOp : EmitC_BinaryOp<"bitwise_xor", []> {
}];
}

def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", []> {
def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> {
let summary = "Opaque call operation";
let description = [{
The `call_opaque` operation represents a C++ function call. The callee
Expand Down Expand Up @@ -257,10 +260,10 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", []> {
let hasVerifier = 1;
}

def EmitC_CastOp : EmitC_Op<"cast", [
DeclareOpInterfaceMethods<CastOpInterface>,
SameOperandsAndResultShape
]> {
def EmitC_CastOp : EmitC_Op<"cast",
[CExpression,
DeclareOpInterfaceMethods<CastOpInterface>,
SameOperandsAndResultShape]> {
let summary = "Cast operation";
let description = [{
The `cast` operation performs an explicit type conversion and is emitted
Expand All @@ -284,7 +287,7 @@ def EmitC_CastOp : EmitC_Op<"cast", [
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
}

def EmitC_CmpOp : EmitC_BinaryOp<"cmp", []> {
def EmitC_CmpOp : EmitC_BinaryOp<"cmp", [CExpression]> {
let summary = "Comparison operation";
let description = [{
With the `cmp` operation the comparison operators ==, !=, <, <=, >, >=, <=>
Expand Down Expand Up @@ -355,7 +358,7 @@ def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> {
let hasVerifier = 1;
}

def EmitC_DivOp : EmitC_BinaryOp<"div", []> {
def EmitC_DivOp : EmitC_BinaryOp<"div", [CExpression]> {
let summary = "Division operation";
let description = [{
With the `div` operation the arithmetic operator / (division) can
Expand Down Expand Up @@ -409,9 +412,8 @@ def EmitC_ExpressionOp : EmitC_Op<"expression",
int32_t v7 = foo(v1 + v2) * (v3 + v4);
```

The operations allowed within expression body are `emitc.add`,
`emitc.apply`, `emitc.call_opaque`, `emitc.cast`, `emitc.cmp`, `emitc.div`,
`emitc.mul`, `emitc.rem`, and `emitc.sub`.
The operations allowed within expression body are EmitC operations with the
CExpression trait.

When specified, the optional `do_not_inline` indicates that the expression is
to be emitted as seen above, i.e. as the rhs of an EmitC SSA value
Expand All @@ -427,14 +429,9 @@ def EmitC_ExpressionOp : EmitC_Op<"expression",
let assemblyFormat = "attr-dict (`noinline` $do_not_inline^)? `:` type($result) $region";

let extraClassDeclaration = [{
static bool isCExpression(Operation &op) {
return isa<emitc::AddOp, emitc::ApplyOp, emitc::CallOpaqueOp,
emitc::CastOp, emitc::CmpOp, emitc::DivOp, emitc::MulOp,
emitc::RemOp, emitc::SubOp>(op);
}
bool hasSideEffects() {
auto predicate = [](Operation &op) {
assert(isCExpression(op) && "Expected a C expression");
assert(op.hasTrait<OpTrait::emitc::CExpression>() && "Expected a C expression");
// Conservatively assume calls to read and write memory.
if (isa<emitc::CallOpaqueOp>(op))
return true;
Expand Down Expand Up @@ -837,7 +834,7 @@ def EmitC_LogicalOrOp : EmitC_BinaryOp<"logical_or", []> {
let assemblyFormat = "operands attr-dict `:` type(operands)";
}

def EmitC_MulOp : EmitC_BinaryOp<"mul", []> {
def EmitC_MulOp : EmitC_BinaryOp<"mul", [CExpression]> {
let summary = "Multiplication operation";
let description = [{
With the `mul` operation the arithmetic operator * (multiplication) can
Expand All @@ -861,7 +858,7 @@ def EmitC_MulOp : EmitC_BinaryOp<"mul", []> {
let results = (outs FloatIntegerIndexOrOpaqueType);
}

def EmitC_RemOp : EmitC_BinaryOp<"rem", []> {
def EmitC_RemOp : EmitC_BinaryOp<"rem", [CExpression]> {
let summary = "Remainder operation";
let description = [{
With the `rem` operation the arithmetic operator % (remainder) can
Expand All @@ -883,7 +880,7 @@ def EmitC_RemOp : EmitC_BinaryOp<"rem", []> {
let results = (outs IntegerIndexOrOpaqueType);
}

def EmitC_SubOp : EmitC_BinaryOp<"sub", []> {
def EmitC_SubOp : EmitC_BinaryOp<"sub", [CExpression]> {
let summary = "Subtraction operation";
let description = [{
With the `sub` operation the arithmetic operator - (subtraction) can
Expand Down
30 changes: 30 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitCTraits.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
//===- EmitCTraits.h - EmitC trait definitions ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares C++ classes for some of the traits used in the EmitC
// dialect.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_EMITC_IR_EMITCTRAITS_H
#define MLIR_DIALECT_EMITC_IR_EMITCTRAITS_H

#include "mlir/IR/OpDefinition.h"

namespace mlir {
namespace OpTrait {
namespace emitc {

template <typename ConcreteType>
class CExpression : public TraitBase<ConcreteType, CExpression> {};

} // namespace emitc
} // namespace OpTrait
} // namespace mlir

#endif // MLIR_DIALECT_EMITC_IR_EMITCTRAITS_H
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/EmitC/IR/EmitCTraits.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/DialectImplementation.h"
Expand Down Expand Up @@ -244,7 +245,7 @@ LogicalResult ExpressionOp::verify() {
return emitOpError("requires yielded type to match return type");

for (Operation &op : region.front().without_terminator()) {
if (!isCExpression(op))
if (!op.hasTrait<OpTrait::emitc::CExpression>())
return emitOpError("contains an unsupported operation");
if (op.getNumResults() != 1)
return emitOpError("requires exactly one result for each operation");
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ struct FormExpressionsPass
// Wrap each C operator op with an expression op.
OpBuilder builder(context);
auto matchFun = [&](Operation *op) {
if (emitc::ExpressionOp::isCExpression(*op))
if (op->hasTrait<OpTrait::emitc::CExpression>())
createExpression(op, builder);
};
rootOp->walk(matchFun);
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ namespace mlir {
namespace emitc {

ExpressionOp createExpression(Operation *op, OpBuilder &builder) {
assert(ExpressionOp::isCExpression(*op) && "Expected a C expression");
assert(op->hasTrait<OpTrait::emitc::CExpression>() &&
"Expected a C expression");

// Create an expression yielding the value returned by op.
assert(op->getNumResults() == 1 && "Expected exactly one result");
Expand Down