Skip to content

Commit 3d166b2

Browse files
TinaAMDmgehre-amd
authored andcommitted
[mlir][emitc] Restrict integer and float types (llvm#85788)
Restrict which integers and floating-point types are valid in EmitC. This should cover the types which are supported in C++ and is aligned with what the emitter currently supports. The checks are implemented as functions and not fully in tablegen to allow them to be re-used by conversions to EmitC.
1 parent b9fa1b8 commit 3d166b2

File tree

6 files changed

+30
-14
lines changed

6 files changed

+30
-14
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ namespace mlir {
3131
namespace emitc {
3232
void buildTerminatedBody(OpBuilder &builder, Location loc);
3333
/// Determines whether \p type is a valid integer type in EmitC.
34-
bool isValidEmitCIntegerType(mlir::Type type);
34+
bool isSupportedIntegerType(mlir::Type type);
3535
/// Determines whether \p type is a valid floating-point type in EmitC.
36-
bool isValidEmitCFloatType(mlir::Type type);
36+
bool isSupportedFloatType(mlir::Type type);
3737
} // namespace emitc
3838
} // namespace mlir
3939

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ class EmitC_BinaryOp<string mnemonic, list<Trait> traits = []> :
5151
def CExpression : NativeOpTrait<"emitc::CExpression">;
5252

5353
// Types only used in binary arithmetic operations.
54-
def IntegerIndexOrOpaqueType : AnyTypeOf<[Valid_EmitC_Integer_Type, Index, EmitC_OpaqueType]>;
55-
def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[Valid_EmitC_Float_Type, IntegerIndexOrOpaqueType]>;
54+
def IntegerIndexOrOpaqueType : AnyTypeOf<[EmitCIntegerType, Index, EmitC_OpaqueType]>;
55+
def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[EmitCFloatType, IntegerIndexOrOpaqueType]>;
5656

5757
def EmitC_AddOp : EmitC_BinaryOp<"add", [CExpression]> {
5858
let summary = "Addition operation";

mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@ include "mlir/IR/BuiltinTypeInterfaces.td"
2222
// EmitC type definitions
2323
//===----------------------------------------------------------------------===//
2424

25-
def Valid_EmitC_Integer_Type : Type<CPred<"emitc::isValidEmitCIntegerType($_self)">,
26-
"EmitC integer type">;
25+
def EmitCIntegerType : Type<CPred<"emitc::isSupportedIntegerType($_self)">,
26+
"integer type supported by EmitC">;
2727

28-
def Valid_EmitC_Float_Type : Type<CPred<"emitc::isValidEmitCFloatType($_self)">,
29-
"EmitC floating-point type">;
28+
def EmitCFloatType : Type<CPred<"emitc::isSupportedFloatType($_self)">,
29+
"floating-point type supported by EmitC">;
3030

3131
class EmitC_Type<string name, string typeMnemonic, list<Trait> traits = []>
3232
: TypeDef<EmitC_Dialect, name, traits> {

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ void mlir::emitc::buildTerminatedBody(OpBuilder &builder, Location loc) {
5454
builder.create<emitc::YieldOp>(loc);
5555
}
5656

57-
bool mlir::emitc::isValidEmitCIntegerType(Type type) {
57+
bool mlir::emitc::isSupportedIntegerType(Type type) {
5858
if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
5959
switch (intType.getWidth()) {
6060
case 1:
@@ -70,7 +70,7 @@ bool mlir::emitc::isValidEmitCIntegerType(Type type) {
7070
return false;
7171
}
7272

73-
bool mlir::emitc::isValidEmitCFloatType(Type type) {
73+
bool mlir::emitc::isSupportedFloatType(Type type) {
7474
if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
7575
switch (floatType.getWidth()) {
7676
case 32:

mlir/test/Dialect/EmitC/invalid_ops.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,31 +170,31 @@ func.func @add_float_pointer(%arg0: f32, %arg1: !emitc.ptr<f32>) {
170170
// -----
171171

172172
func.func @div_tensor(%arg0: tensor<i32>, %arg1: tensor<i32>) {
173-
// expected-error @+1 {{'emitc.div' op operand #0 must be EmitC floating-point type or EmitC integer type or index or EmitC opaque type, but got 'tensor<i32>'}}
173+
// expected-error @+1 {{'emitc.div' op operand #0 must be floating-point type supported by EmitC or integer type supported by EmitC or index or EmitC opaque type, but got 'tensor<i32>'}}
174174
%1 = "emitc.div" (%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
175175
return
176176
}
177177

178178
// -----
179179

180180
func.func @mul_tensor(%arg0: tensor<i32>, %arg1: tensor<i32>) {
181-
// expected-error @+1 {{'emitc.mul' op operand #0 must be EmitC floating-point type or EmitC integer type or index or EmitC opaque type, but got 'tensor<i32>'}}
181+
// expected-error @+1 {{'emitc.mul' op operand #0 must be floating-point type supported by EmitC or integer type supported by EmitC or index or EmitC opaque type, but got 'tensor<i32>'}}
182182
%1 = "emitc.mul" (%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
183183
return
184184
}
185185

186186
// -----
187187

188188
func.func @rem_tensor(%arg0: tensor<i32>, %arg1: tensor<i32>) {
189-
// expected-error @+1 {{'emitc.rem' op operand #0 must be EmitC integer type or index or EmitC opaque type, but got 'tensor<i32>'}}
189+
// expected-error @+1 {{'emitc.rem' op operand #0 must be integer type supported by EmitC or index or EmitC opaque type, but got 'tensor<i32>'}}
190190
%1 = "emitc.rem" (%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
191191
return
192192
}
193193

194194
// -----
195195

196196
func.func @rem_float(%arg0: f32, %arg1: f32) {
197-
// expected-error @+1 {{'emitc.rem' op operand #0 must be EmitC integer type or index or EmitC opaque type, but got 'f32'}}
197+
// expected-error @+1 {{'emitc.rem' op operand #0 must be integer type supported by EmitC or index or EmitC opaque type, but got 'f32'}}
198198
%1 = "emitc.rem" (%arg0, %arg1) : (f32, f32) -> f32
199199
return
200200
}

mlir/test/Dialect/EmitC/invalid_types.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,19 @@ func.func @illegal_array_with_tensor_element_type(
8181
%arg0: !emitc.array<4xtensor<4xi32>>
8282
) {
8383
}
84+
85+
// -----
86+
87+
func.func @illegal_integer_type(%arg0: i11, %arg1: i11) -> i11 {
88+
// expected-error @+1 {{'emitc.mul' op operand #0 must be floating-point type supported by EmitC or integer type supported by EmitC or index or EmitC opaque type, but got 'i11'}}
89+
%mul = "emitc.mul" (%arg0, %arg1) : (i11, i11) -> i11
90+
return
91+
}
92+
93+
// -----
94+
95+
func.func @illegal_float_type(%arg0: f80, %arg1: f80) {
96+
// expected-error @+1 {{'emitc.mul' op operand #0 must be floating-point type supported by EmitC or integer type supported by EmitC or index or EmitC opaque type, but got 'f80'}}
97+
%mul = "emitc.mul" (%arg0, %arg1) : (f80, f80) -> f80
98+
return
99+
}

0 commit comments

Comments
 (0)