Skip to content

[mlir][emitc] Restrict integer and float types #85788

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 20, 2024

Conversation

TinaAMD
Copy link
Contributor

@TinaAMD TinaAMD commented Mar 19, 2024

Restrict which integers and floating-point types are valid in EmitC. This should cover the types which are supported in C++ and is aligned with what the emitter currently supports.

The checks are implemented as functions and not fully in tablegen to allow them to be re-used by conversions to EmitC.

Restrict which integers and floating-point types are valid in EmitC. This should cover the types which are supported in C++ and is aligned with what the emitter currently supports.

The checks are implemented as functions and not fully in tablegen to allow them to be re-used by conversions to EmitC.
@llvmbot
Copy link
Member

llvmbot commented Mar 19, 2024

@llvm/pr-subscribers-mlir-emitc

@llvm/pr-subscribers-mlir

Author: Tina Jung (TinaAMD)

Changes

Restrict which integers and floating-point types are valid in EmitC. This should cover the types which are supported in C++ and is aligned with what the emitter currently supports.

The checks are implemented as functions and not fully in tablegen to allow them to be re-used by conversions to EmitC.


Full diff: https://github.com/llvm/llvm-project/pull/85788.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/EmitC/IR/EmitC.h (+4)
  • (modified) mlir/include/mlir/Dialect/EmitC/IR/EmitC.td (+2-2)
  • (modified) mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td (+6)
  • (modified) mlir/lib/Dialect/EmitC/IR/EmitC.cpp (+29)
  • (modified) mlir/test/Dialect/EmitC/invalid_ops.mlir (+4-4)
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
index 1f0df3cb336b12..f3d250fbdc2863 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
@@ -30,6 +30,10 @@
 namespace mlir {
 namespace emitc {
 void buildTerminatedBody(OpBuilder &builder, Location loc);
+/// Determines whether \p type is a valid integer type in EmitC.
+bool isValidEmitCIntegerType(mlir::Type type);
+/// Determines whether \p type is a valid floating-point type in EmitC.
+bool isValidEmitCFloatType(mlir::Type type);
 } // namespace emitc
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 78bfd561171f50..8e6a8d48b0f744 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -51,8 +51,8 @@ class EmitC_BinaryOp<string mnemonic, list<Trait> traits = []> :
 def CExpression : NativeOpTrait<"emitc::CExpression">;
 
 // Types only used in binary arithmetic operations.
-def IntegerIndexOrOpaqueType : AnyTypeOf<[AnyInteger, Index, EmitC_OpaqueType]>;
-def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[AnyFloat, IntegerIndexOrOpaqueType]>;
+def IntegerIndexOrOpaqueType : AnyTypeOf<[Valid_EmitC_Integer_Type, Index, EmitC_OpaqueType]>;
+def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[Valid_EmitC_Float_Type, IntegerIndexOrOpaqueType]>;
 
 def EmitC_AddOp : EmitC_BinaryOp<"add", [CExpression]> {
   let summary = "Addition operation";
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
index a2ba45a1f6a12b..ed51c6d05d567e 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
@@ -22,6 +22,12 @@ include "mlir/IR/BuiltinTypeInterfaces.td"
 // EmitC type definitions
 //===----------------------------------------------------------------------===//
 
+def Valid_EmitC_Integer_Type : Type<CPred<"emitc::isValidEmitCIntegerType($_self)">,
+    "EmitC integer type">;
+
+def Valid_EmitC_Float_Type : Type<CPred<"emitc::isValidEmitCFloatType($_self)">,
+    "EmitC floating-point type">;
+
 class EmitC_Type<string name, string typeMnemonic, list<Trait> traits = []>
     : TypeDef<EmitC_Dialect, name, traits> {
   let mnemonic = typeMnemonic;
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index e401a83bcb42e6..f6e2168e0bac7b 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -54,6 +54,35 @@ void mlir::emitc::buildTerminatedBody(OpBuilder &builder, Location loc) {
   builder.create<emitc::YieldOp>(loc);
 }
 
+bool mlir::emitc::isValidEmitCIntegerType(Type type) {
+  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
+    switch (intType.getWidth()) {
+    case 1:
+    case 8:
+    case 16:
+    case 32:
+    case 64:
+      return true;
+    default:
+      return false;
+    }
+  }
+  return false;
+}
+
+bool mlir::emitc::isValidEmitCFloatType(Type type) {
+  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
+    switch (floatType.getWidth()) {
+    case 32:
+    case 64:
+      return true;
+    default:
+      return false;
+    }
+  }
+  return false;
+}
+
 /// Check that the type of the initial value is compatible with the operations
 /// result type.
 static LogicalResult verifyInitializationAttribute(Operation *op,
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 6294c853d99931..6cc833cf396462 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -170,7 +170,7 @@ func.func @add_float_pointer(%arg0: f32, %arg1: !emitc.ptr<f32>) {
 // -----
 
 func.func @div_tensor(%arg0: tensor<i32>, %arg1: tensor<i32>) {
-    // expected-error @+1 {{'emitc.div' op operand #0 must be floating-point or integer or index or EmitC opaque type, but got 'tensor<i32>'}}
+    // expected-error @+1 {{'emitc.div' op operand #0 must be EmitC floating-point type or EmitC integer type or index or EmitC opaque type, but got 'tensor<i32>'}}
     %1 = "emitc.div" (%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
     return
 }
@@ -178,7 +178,7 @@ func.func @div_tensor(%arg0: tensor<i32>, %arg1: tensor<i32>) {
 // -----
 
 func.func @mul_tensor(%arg0: tensor<i32>, %arg1: tensor<i32>) {
-    // expected-error @+1 {{'emitc.mul' op operand #0 must be floating-point or integer or index or EmitC opaque type, but got 'tensor<i32>'}}
+    // expected-error @+1 {{'emitc.mul' op operand #0 must be EmitC floating-point type or EmitC integer type or index or EmitC opaque type, but got 'tensor<i32>'}}
     %1 = "emitc.mul" (%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
     return
 }
@@ -186,7 +186,7 @@ func.func @mul_tensor(%arg0: tensor<i32>, %arg1: tensor<i32>) {
 // -----
 
 func.func @rem_tensor(%arg0: tensor<i32>, %arg1: tensor<i32>) {
-    // expected-error @+1 {{'emitc.rem' op operand #0 must be integer or index or EmitC opaque type, but got 'tensor<i32>'}}
+    // expected-error @+1 {{'emitc.rem' op operand #0 must be EmitC integer type or index or EmitC opaque type, but got 'tensor<i32>'}}
     %1 = "emitc.rem" (%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
     return
 }
@@ -194,7 +194,7 @@ func.func @rem_tensor(%arg0: tensor<i32>, %arg1: tensor<i32>) {
 // -----
 
 func.func @rem_float(%arg0: f32, %arg1: f32) {
-    // expected-error @+1 {{'emitc.rem' op operand #0 must be integer or index or EmitC opaque type, but got 'f32'}}
+    // expected-error @+1 {{'emitc.rem' op operand #0 must be EmitC integer type or index or EmitC opaque type, but got 'f32'}}
     %1 = "emitc.rem" (%arg0, %arg1) : (f32, f32) -> f32
     return
 }

@TinaAMD TinaAMD requested review from marbre and simon-camp March 19, 2024 14:03
Copy link
Member

@marbre marbre left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for working on this @TinaAMD! Some naming suggestions. Furthermore, could you please add some tests with invalid types? That would be really helpful.

@TinaAMD TinaAMD requested a review from marbre March 20, 2024 10:28
Copy link
Member

@marbre marbre left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, lgtm! Let me know if I should merge on your behalf.

@TinaAMD
Copy link
Contributor Author

TinaAMD commented Mar 20, 2024

I got commit access recently, so I should be able to merge, but thank you! I'll give @simon-camp some time to give feedback.

Copy link
Contributor

@simon-camp simon-camp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@TinaAMD TinaAMD merged commit 647d75d into llvm:main Mar 20, 2024
@TinaAMD TinaAMD deleted the tina.restrict-emitc-float-integer-types branch March 20, 2024 15:00
mgehre-amd pushed a commit to Xilinx/llvm-project that referenced this pull request Mar 22, 2024
Restrict which integers and floating-point types are valid in EmitC.
This should cover the types which are supported in C++ and is aligned
with what the emitter currently supports.

The checks are implemented as functions and not fully in tablegen to
allow them to be re-used by conversions to EmitC.
chencha3 pushed a commit to chencha3/llvm-project that referenced this pull request Mar 23, 2024
Restrict which integers and floating-point types are valid in EmitC.
This should cover the types which are supported in C++ and is aligned
with what the emitter currently supports.

The checks are implemented as functions and not fully in tablegen to
allow them to be re-used by conversions to EmitC.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants