-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][emitc] Restrict integer and float types #85788
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][emitc] Restrict integer and float types #85788
Conversation
Restrict which integers and floating-point types are valid in EmitC. This should cover the types which are supported in C++ and is aligned with what the emitter currently supports. The checks are implemented as functions and not fully in tablegen to allow them to be re-used by conversions to EmitC.
@llvm/pr-subscribers-mlir-emitc @llvm/pr-subscribers-mlir Author: Tina Jung (TinaAMD) ChangesRestrict which integers and floating-point types are valid in EmitC. This should cover the types which are supported in C++ and is aligned with what the emitter currently supports. The checks are implemented as functions and not fully in tablegen to allow them to be re-used by conversions to EmitC. Full diff: https://github.com/llvm/llvm-project/pull/85788.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
index 1f0df3cb336b12..f3d250fbdc2863 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
@@ -30,6 +30,10 @@
namespace mlir {
namespace emitc {
void buildTerminatedBody(OpBuilder &builder, Location loc);
+/// Determines whether \p type is a valid integer type in EmitC.
+bool isValidEmitCIntegerType(mlir::Type type);
+/// Determines whether \p type is a valid floating-point type in EmitC.
+bool isValidEmitCFloatType(mlir::Type type);
} // namespace emitc
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 78bfd561171f50..8e6a8d48b0f744 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -51,8 +51,8 @@ class EmitC_BinaryOp<string mnemonic, list<Trait> traits = []> :
def CExpression : NativeOpTrait<"emitc::CExpression">;
// Types only used in binary arithmetic operations.
-def IntegerIndexOrOpaqueType : AnyTypeOf<[AnyInteger, Index, EmitC_OpaqueType]>;
-def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[AnyFloat, IntegerIndexOrOpaqueType]>;
+def IntegerIndexOrOpaqueType : AnyTypeOf<[Valid_EmitC_Integer_Type, Index, EmitC_OpaqueType]>;
+def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[Valid_EmitC_Float_Type, IntegerIndexOrOpaqueType]>;
def EmitC_AddOp : EmitC_BinaryOp<"add", [CExpression]> {
let summary = "Addition operation";
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
index a2ba45a1f6a12b..ed51c6d05d567e 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
@@ -22,6 +22,12 @@ include "mlir/IR/BuiltinTypeInterfaces.td"
// EmitC type definitions
//===----------------------------------------------------------------------===//
+def Valid_EmitC_Integer_Type : Type<CPred<"emitc::isValidEmitCIntegerType($_self)">,
+ "EmitC integer type">;
+
+def Valid_EmitC_Float_Type : Type<CPred<"emitc::isValidEmitCFloatType($_self)">,
+ "EmitC floating-point type">;
+
class EmitC_Type<string name, string typeMnemonic, list<Trait> traits = []>
: TypeDef<EmitC_Dialect, name, traits> {
let mnemonic = typeMnemonic;
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index e401a83bcb42e6..f6e2168e0bac7b 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -54,6 +54,35 @@ void mlir::emitc::buildTerminatedBody(OpBuilder &builder, Location loc) {
builder.create<emitc::YieldOp>(loc);
}
+bool mlir::emitc::isValidEmitCIntegerType(Type type) {
+ if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
+ switch (intType.getWidth()) {
+ case 1:
+ case 8:
+ case 16:
+ case 32:
+ case 64:
+ return true;
+ default:
+ return false;
+ }
+ }
+ return false;
+}
+
+bool mlir::emitc::isValidEmitCFloatType(Type type) {
+ if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
+ switch (floatType.getWidth()) {
+ case 32:
+ case 64:
+ return true;
+ default:
+ return false;
+ }
+ }
+ return false;
+}
+
/// Check that the type of the initial value is compatible with the operations
/// result type.
static LogicalResult verifyInitializationAttribute(Operation *op,
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 6294c853d99931..6cc833cf396462 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -170,7 +170,7 @@ func.func @add_float_pointer(%arg0: f32, %arg1: !emitc.ptr<f32>) {
// -----
func.func @div_tensor(%arg0: tensor<i32>, %arg1: tensor<i32>) {
- // expected-error @+1 {{'emitc.div' op operand #0 must be floating-point or integer or index or EmitC opaque type, but got 'tensor<i32>'}}
+ // expected-error @+1 {{'emitc.div' op operand #0 must be EmitC floating-point type or EmitC integer type or index or EmitC opaque type, but got 'tensor<i32>'}}
%1 = "emitc.div" (%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
return
}
@@ -178,7 +178,7 @@ func.func @div_tensor(%arg0: tensor<i32>, %arg1: tensor<i32>) {
// -----
func.func @mul_tensor(%arg0: tensor<i32>, %arg1: tensor<i32>) {
- // expected-error @+1 {{'emitc.mul' op operand #0 must be floating-point or integer or index or EmitC opaque type, but got 'tensor<i32>'}}
+ // expected-error @+1 {{'emitc.mul' op operand #0 must be EmitC floating-point type or EmitC integer type or index or EmitC opaque type, but got 'tensor<i32>'}}
%1 = "emitc.mul" (%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
return
}
@@ -186,7 +186,7 @@ func.func @mul_tensor(%arg0: tensor<i32>, %arg1: tensor<i32>) {
// -----
func.func @rem_tensor(%arg0: tensor<i32>, %arg1: tensor<i32>) {
- // expected-error @+1 {{'emitc.rem' op operand #0 must be integer or index or EmitC opaque type, but got 'tensor<i32>'}}
+ // expected-error @+1 {{'emitc.rem' op operand #0 must be EmitC integer type or index or EmitC opaque type, but got 'tensor<i32>'}}
%1 = "emitc.rem" (%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
return
}
@@ -194,7 +194,7 @@ func.func @rem_tensor(%arg0: tensor<i32>, %arg1: tensor<i32>) {
// -----
func.func @rem_float(%arg0: f32, %arg1: f32) {
- // expected-error @+1 {{'emitc.rem' op operand #0 must be integer or index or EmitC opaque type, but got 'f32'}}
+ // expected-error @+1 {{'emitc.rem' op operand #0 must be EmitC integer type or index or EmitC opaque type, but got 'f32'}}
%1 = "emitc.rem" (%arg0, %arg1) : (f32, f32) -> f32
return
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for working on this @TinaAMD! Some naming suggestions. Furthermore, could you please add some tests with invalid types? That would be really helpful.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, lgtm! Let me know if I should merge on your behalf.
I got commit access recently, so I should be able to merge, but thank you! I'll give @simon-camp some time to give feedback. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
Restrict which integers and floating-point types are valid in EmitC. This should cover the types which are supported in C++ and is aligned with what the emitter currently supports. The checks are implemented as functions and not fully in tablegen to allow them to be re-used by conversions to EmitC.
Restrict which integers and floating-point types are valid in EmitC. This should cover the types which are supported in C++ and is aligned with what the emitter currently supports. The checks are implemented as functions and not fully in tablegen to allow them to be re-used by conversions to EmitC.
Restrict which integers and floating-point types are valid in EmitC. This should cover the types which are supported in C++ and is aligned with what the emitter currently supports.
The checks are implemented as functions and not fully in tablegen to allow them to be re-used by conversions to EmitC.