Skip to content

[mlir][EmitC] Introduce a CExpression trait #84177

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 7, 2024

Conversation

marbre
Copy link
Member

@marbre marbre commented Mar 6, 2024

This adds a CExpression trait and replaces the isCExpression() function.

@llvmbot
Copy link
Member

llvmbot commented Mar 6, 2024

@llvm/pr-subscribers-mlir-emitc

@llvm/pr-subscribers-mlir

Author: Marius Brehler (marbre)

Changes

This adds a CExpression trait and replaces the isCExpression() function.


Full diff: https://github.com/llvm/llvm-project/pull/84177.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/EmitC/IR/EmitC.h (+1)
  • (modified) mlir/include/mlir/Dialect/EmitC/IR/EmitC.td (+30-31)
  • (added) mlir/include/mlir/Dialect/EmitC/IR/EmitCTraits.h (+30)
  • (modified) mlir/lib/Dialect/EmitC/IR/EmitC.cpp (+1)
  • (modified) mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp (+1-1)
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
index 3d38744527d599..1f0df3cb336b12 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
@@ -14,6 +14,7 @@
 #define MLIR_DIALECT_EMITC_IR_EMITC_H
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/EmitC/IR/EmitCTraits.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 6bef395e94eb9d..db0e2d10960d72 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -47,11 +47,14 @@ class EmitC_BinaryOp<string mnemonic, list<Trait> traits = []> :
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
 }
 
+// EmitC OpTrait
+def CExpression : NativeOpTrait<"emitc::CExpression">;
+
 // Types only used in binary arithmetic operations.
 def IntegerIndexOrOpaqueType : AnyTypeOf<[AnyInteger, Index, EmitC_OpaqueType]>;
 def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[AnyFloat, IntegerIndexOrOpaqueType]>;
 
-def EmitC_AddOp : EmitC_BinaryOp<"add", []> {
+def EmitC_AddOp : EmitC_BinaryOp<"add", [CExpression]> {
   let summary = "Addition operation";
   let description = [{
     With the `add` operation the arithmetic operator + (addition) can
@@ -74,7 +77,7 @@ def EmitC_AddOp : EmitC_BinaryOp<"add", []> {
   let hasVerifier = 1;
 }
 
-def EmitC_ApplyOp : EmitC_Op<"apply", []> {
+def EmitC_ApplyOp : EmitC_Op<"apply", [CExpression]> {
   let summary = "Apply operation";
   let description = [{
     With the `apply` operation the operators & (address of) and * (contents of)
@@ -103,7 +106,7 @@ def EmitC_ApplyOp : EmitC_Op<"apply", []> {
   let hasVerifier = 1;
 }
 
-def EmitC_BitwiseAndOp : EmitC_BinaryOp<"bitwise_and", []> {
+def EmitC_BitwiseAndOp : EmitC_BinaryOp<"bitwise_and", [CExpression]> {
   let summary = "Bitwise and operation";
   let description = [{
     With the `bitwise_and` operation the bitwise operator & (and) can
@@ -121,7 +124,8 @@ def EmitC_BitwiseAndOp : EmitC_BinaryOp<"bitwise_and", []> {
   }];
 }
 
-def EmitC_BitwiseLeftShiftOp : EmitC_BinaryOp<"bitwise_left_shift", []> {
+def EmitC_BitwiseLeftShiftOp : EmitC_BinaryOp<"bitwise_left_shift",
+    [CExpression]> {
   let summary = "Bitwise left shift operation";
   let description = [{
     With the `bitwise_left_shift` operation the bitwise operator <<
@@ -139,7 +143,7 @@ def EmitC_BitwiseLeftShiftOp : EmitC_BinaryOp<"bitwise_left_shift", []> {
   }];
 }
 
-def EmitC_BitwiseNotOp : EmitC_UnaryOp<"bitwise_not", []> {
+def EmitC_BitwiseNotOp : EmitC_UnaryOp<"bitwise_not", [CExpression]> {
   let summary = "Bitwise not operation";
   let description = [{
     With the `bitwise_not` operation the bitwise operator ~ (not) can
@@ -157,7 +161,7 @@ def EmitC_BitwiseNotOp : EmitC_UnaryOp<"bitwise_not", []> {
   }];
 }
 
-def EmitC_BitwiseOrOp : EmitC_BinaryOp<"bitwise_or", []> {
+def EmitC_BitwiseOrOp : EmitC_BinaryOp<"bitwise_or", [CExpression]> {
   let summary = "Bitwise or operation";
   let description = [{
     With the `bitwise_or` operation the bitwise operator | (or)
@@ -175,7 +179,8 @@ def EmitC_BitwiseOrOp : EmitC_BinaryOp<"bitwise_or", []> {
   }];
 }
 
-def EmitC_BitwiseRightShiftOp : EmitC_BinaryOp<"bitwise_right_shift", []> {
+def EmitC_BitwiseRightShiftOp : EmitC_BinaryOp<"bitwise_right_shift",
+    [CExpression]> {
   let summary = "Bitwise right shift operation";
   let description = [{
     With the `bitwise_right_shift` operation the bitwise operator >>
@@ -193,7 +198,7 @@ def EmitC_BitwiseRightShiftOp : EmitC_BinaryOp<"bitwise_right_shift", []> {
   }];
 }
 
-def EmitC_BitwiseXorOp : EmitC_BinaryOp<"bitwise_xor", []> {
+def EmitC_BitwiseXorOp : EmitC_BinaryOp<"bitwise_xor", [CExpression]> {
   let summary = "Bitwise xor operation";
   let description = [{
     With the `bitwise_xor` operation the bitwise operator ^ (xor)
@@ -211,7 +216,7 @@ def EmitC_BitwiseXorOp : EmitC_BinaryOp<"bitwise_xor", []> {
   }];
 }
 
-def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", []> {
+def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> {
   let summary = "Opaque call operation";
   let description = [{
     The `call_opaque` operation represents a C++ function call. The callee
@@ -257,10 +262,10 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", []> {
   let hasVerifier = 1;
 }
 
-def EmitC_CastOp : EmitC_Op<"cast", [
-    DeclareOpInterfaceMethods<CastOpInterface>,
-    SameOperandsAndResultShape
-  ]> {
+def EmitC_CastOp : EmitC_Op<"cast",
+    [CExpression,
+     DeclareOpInterfaceMethods<CastOpInterface>,
+     SameOperandsAndResultShape]> {
   let summary = "Cast operation";
   let description = [{
     The `cast` operation performs an explicit type conversion and is emitted
@@ -284,7 +289,7 @@ def EmitC_CastOp : EmitC_Op<"cast", [
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
 }
 
-def EmitC_CmpOp : EmitC_BinaryOp<"cmp", []> {
+def EmitC_CmpOp : EmitC_BinaryOp<"cmp", [CExpression]> {
   let summary = "Comparison operation";
   let description = [{
     With the `cmp` operation the comparison operators ==, !=, <, <=, >, >=, <=> 
@@ -355,7 +360,7 @@ def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> {
   let hasVerifier = 1;
 }
 
-def EmitC_DivOp : EmitC_BinaryOp<"div", []> {
+def EmitC_DivOp : EmitC_BinaryOp<"div", [CExpression]> {
   let summary = "Division operation";
   let description = [{
     With the `div` operation the arithmetic operator / (division) can
@@ -409,9 +414,8 @@ def EmitC_ExpressionOp : EmitC_Op<"expression",
     int32_t v7 = foo(v1 + v2) * (v3 + v4);
     ```
 
-    The operations allowed within expression body are `emitc.add`,
-    `emitc.apply`, `emitc.call_opaque`, `emitc.cast`, `emitc.cmp`, `emitc.div`,
-    `emitc.mul`, `emitc.rem`, and `emitc.sub`.
+    The operations allowed within expression body are EmitC operations with the
+    CExpression trait.
 
     When specified, the optional `do_not_inline` indicates that the expression is
     to be emitted as seen above, i.e. as the rhs of an EmitC SSA value
@@ -427,14 +431,9 @@ def EmitC_ExpressionOp : EmitC_Op<"expression",
   let assemblyFormat = "attr-dict (`noinline` $do_not_inline^)? `:` type($result) $region";
 
   let extraClassDeclaration = [{
-    static bool isCExpression(Operation &op) {
-      return isa<emitc::AddOp, emitc::ApplyOp, emitc::CallOpaqueOp,
-                 emitc::CastOp, emitc::CmpOp, emitc::DivOp, emitc::MulOp,
-                 emitc::RemOp, emitc::SubOp>(op);
-    }
     bool hasSideEffects() {
       auto predicate = [](Operation &op) {
-        assert(isCExpression(op) && "Expected a C expression");
+        assert(op.hasTrait<OpTrait::emitc::CExpression>() && "Expected a C expression");
         // Conservatively assume calls to read and write memory.
         if (isa<emitc::CallOpaqueOp>(op))
           return true;
@@ -518,7 +517,7 @@ def EmitC_ForOp : EmitC_Op<"for",
 }
 
 def EmitC_CallOp : EmitC_Op<"call",
-    [CallOpInterface,
+    [CallOpInterface, CExpression,
      DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
   let summary = "call operation";
   let description = [{
@@ -774,7 +773,7 @@ def EmitC_LiteralOp : EmitC_Op<"literal", [Pure]> {
   let assemblyFormat = "$value attr-dict `:` type($result)";
 }
 
-def EmitC_LogicalAndOp : EmitC_BinaryOp<"logical_and", []> {
+def EmitC_LogicalAndOp : EmitC_BinaryOp<"logical_and", [CExpression]> {
   let summary = "Logical and operation";
   let description = [{
     With the `logical_and` operation the logical operator && (and) can
@@ -795,7 +794,7 @@ def EmitC_LogicalAndOp : EmitC_BinaryOp<"logical_and", []> {
   let assemblyFormat = "operands attr-dict `:` type(operands)";
 }
 
-def EmitC_LogicalNotOp : EmitC_UnaryOp<"logical_not", []> {
+def EmitC_LogicalNotOp : EmitC_UnaryOp<"logical_not", [CExpression]> {
   let summary = "Logical not operation";
   let description = [{
     With the `logical_not` operation the logical operator ! (negation) can
@@ -816,7 +815,7 @@ def EmitC_LogicalNotOp : EmitC_UnaryOp<"logical_not", []> {
   let assemblyFormat = "operands attr-dict `:` type(operands)";
 }
 
-def EmitC_LogicalOrOp : EmitC_BinaryOp<"logical_or", []> {
+def EmitC_LogicalOrOp : EmitC_BinaryOp<"logical_or", [CExpression]> {
   let summary = "Logical or operation";
   let description = [{
     With the `logical_or` operation the logical operator || (inclusive or)
@@ -837,7 +836,7 @@ def EmitC_LogicalOrOp : EmitC_BinaryOp<"logical_or", []> {
   let assemblyFormat = "operands attr-dict `:` type(operands)";
 }
 
-def EmitC_MulOp : EmitC_BinaryOp<"mul", []> {
+def EmitC_MulOp : EmitC_BinaryOp<"mul", [CExpression]> {
   let summary = "Multiplication operation";
   let description = [{
     With the `mul` operation the arithmetic operator * (multiplication) can
@@ -861,7 +860,7 @@ def EmitC_MulOp : EmitC_BinaryOp<"mul", []> {
   let results = (outs FloatIntegerIndexOrOpaqueType);
 }
 
-def EmitC_RemOp : EmitC_BinaryOp<"rem", []> {
+def EmitC_RemOp : EmitC_BinaryOp<"rem", [CExpression]> {
   let summary = "Remainder operation";
   let description = [{
     With the `rem` operation the arithmetic operator % (remainder) can
@@ -883,7 +882,7 @@ def EmitC_RemOp : EmitC_BinaryOp<"rem", []> {
   let results = (outs IntegerIndexOrOpaqueType);
 }
 
-def EmitC_SubOp : EmitC_BinaryOp<"sub", []> {
+def EmitC_SubOp : EmitC_BinaryOp<"sub", [CExpression]> {
   let summary = "Subtraction operation";
   let description = [{
     With the `sub` operation the arithmetic operator - (subtraction) can
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTraits.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTraits.h
new file mode 100644
index 00000000000000..c1602dfce4b484
--- /dev/null
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTraits.h
@@ -0,0 +1,30 @@
+//===- EmitCTraits.h - EmitC trait definitions ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares C++ classes for some of the traits used in the EmitC
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_EMITC_IR_EMITCTRAITS_H
+#define MLIR_DIALECT_EMITC_IR_EMITCTRAITS_H
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace OpTrait {
+namespace emitc {
+
+template <typename ConcreteType>
+class CExpression : public TraitBase<ConcreteType, CExpression> {};
+
+} // namespace emitc
+} // namespace OpTrait
+} // namespace mlir
+
+#endif // MLIR_DIALECT_EMITC_IR_EMITCTRAITS_H
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 4df8149b94c95f..6678b7a35a39d8 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/EmitC/IR/EmitCTraits.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/DialectImplementation.h"
diff --git a/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp b/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp
index 21212155ffb22f..5b03f81b305fd5 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp
@@ -36,7 +36,7 @@ struct FormExpressionsPass
     // Wrap each C operator op with an expression op.
     OpBuilder builder(context);
     auto matchFun = [&](Operation *op) {
-      if (emitc::ExpressionOp::isCExpression(*op))
+      if (op->hasTrait<OpTrait::emitc::CExpression>())
         createExpression(op, builder);
     };
     rootOp->walk(matchFun);

@marbre marbre requested a review from jpienaar March 6, 2024 14:35
@marbre marbre force-pushed the emitc.cexpression branch from eb1bfca to 66bef38 Compare March 6, 2024 14:53
Copy link
Contributor

@simon-camp simon-camp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. This is a nice simplification!

This adds a `CExpression` trait and replaces the `isCExpression()`
function.
@marbre marbre force-pushed the emitc.cexpression branch from 66bef38 to f0b3f8d Compare March 6, 2024 17:36
@marbre marbre merged commit 7c63431 into llvm:main Mar 7, 2024
@marbre marbre deleted the emitc.cexpression branch March 7, 2024 07:37
mgehre-amd pushed a commit to Xilinx/llvm-project that referenced this pull request Mar 11, 2024
This adds a `CExpression` trait and replaces the `isCExpression()`
function.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants