Skip to content

Commit 95ffa8a

Browse files
authored
[mlir][emitc] Restrict types in EmitC (#88391)
Restrict the types which are valid for EmitC operations. Use what is currently supported by the emitter as restriction. Define a utility functions for valid types, such that they can be used to restrict the operations in table gen as well as being available for reuse in dialect conversions.
1 parent e2a72fa commit 95ffa8a

File tree

5 files changed

+118
-26
lines changed

5 files changed

+118
-26
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ namespace mlir {
3131
namespace emitc {
3232
void buildTerminatedBody(OpBuilder &builder, Location loc);
3333

34+
/// Determines whether \p type is valid in EmitC.
35+
bool isSupportedEmitCType(mlir::Type type);
36+
3437
/// Determines whether \p type is a valid integer type in EmitC.
3538
bool isSupportedIntegerType(mlir::Type type);
3639

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,16 @@ class EmitC_Op<string mnemonic, list<Trait> traits = []>
3434
// Base class for unary operations.
3535
class EmitC_UnaryOp<string mnemonic, list<Trait> traits = []> :
3636
EmitC_Op<mnemonic, traits> {
37-
let arguments = (ins AnyType);
38-
let results = (outs AnyType);
37+
let arguments = (ins EmitCType);
38+
let results = (outs EmitCType);
3939
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
4040
}
4141

4242
// Base class for binary operations.
4343
class EmitC_BinaryOp<string mnemonic, list<Trait> traits = []> :
4444
EmitC_Op<mnemonic, traits> {
45-
let arguments = (ins AnyType:$lhs, AnyType:$rhs);
46-
let results = (outs AnyType);
45+
let arguments = (ins EmitCType:$lhs, EmitCType:$rhs);
46+
let results = (outs EmitCType);
4747
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
4848
}
4949

@@ -97,9 +97,9 @@ def EmitC_ApplyOp : EmitC_Op<"apply", [CExpression]> {
9797
}];
9898
let arguments = (ins
9999
Arg<StrAttr, "the operator to apply">:$applicableOperator,
100-
AnyType:$operand
100+
EmitCType:$operand
101101
);
102-
let results = (outs AnyType:$result);
102+
let results = (outs EmitCType:$result);
103103
let assemblyFormat = [{
104104
$applicableOperator `(` $operand `)` attr-dict `:` functional-type($operand, results)
105105
}];
@@ -240,9 +240,9 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> {
240240
Arg<StrAttr, "the C++ function to call">:$callee,
241241
Arg<OptionalAttr<ArrayAttr>, "the order of operands and further attributes">:$args,
242242
Arg<OptionalAttr<ArrayAttr>, "template arguments">:$template_args,
243-
Variadic<AnyType>:$operands
243+
Variadic<EmitCType>:$operands
244244
);
245-
let results = (outs Variadic<AnyType>);
245+
let results = (outs Variadic<EmitCType>);
246246
let builders = [
247247
OpBuilder<(ins
248248
"::mlir::TypeRange":$resultTypes,
@@ -284,8 +284,8 @@ def EmitC_CastOp : EmitC_Op<"cast",
284284
```
285285
}];
286286

287-
let arguments = (ins AnyType:$source);
288-
let results = (outs AnyType:$dest);
287+
let arguments = (ins EmitCType:$source);
288+
let results = (outs EmitCType:$dest);
289289
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
290290
}
291291

@@ -323,9 +323,9 @@ def EmitC_CmpOp : EmitC_BinaryOp<"cmp", [CExpression]> {
323323
}];
324324

325325
let arguments = (ins EmitC_CmpPredicateAttr:$predicate,
326-
AnyType:$lhs,
327-
AnyType:$rhs);
328-
let results = (outs AnyType);
326+
EmitCType:$lhs,
327+
EmitCType:$rhs);
328+
let results = (outs EmitCType);
329329

330330
let assemblyFormat = "$predicate `,` operands attr-dict `:` functional-type(operands, results)";
331331
}
@@ -353,7 +353,7 @@ def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> {
353353
}];
354354

355355
let arguments = (ins EmitC_OpaqueOrTypedAttr:$value);
356-
let results = (outs AnyType);
356+
let results = (outs EmitCType);
357357

358358
let hasFolder = 1;
359359
let hasVerifier = 1;
@@ -423,7 +423,7 @@ def EmitC_ExpressionOp : EmitC_Op<"expression",
423423
}];
424424

425425
let arguments = (ins UnitAttr:$do_not_inline);
426-
let results = (outs AnyType:$result);
426+
let results = (outs EmitCType:$result);
427427
let regions = (region SizedRegion<1>:$region);
428428

429429
let hasVerifier = 1;
@@ -531,8 +531,8 @@ def EmitC_CallOp : EmitC_Op<"call",
531531
%2 = emitc.call @my_add(%0, %1) : (f32, f32) -> f32
532532
```
533533
}];
534-
let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<AnyType>:$operands);
535-
let results = (outs Variadic<AnyType>);
534+
let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<EmitCType>:$operands);
535+
let results = (outs Variadic<EmitCType>);
536536

537537
let builders = [
538538
OpBuilder<(ins "FuncOp":$callee, CArg<"ValueRange", "{}">:$operands), [{
@@ -722,7 +722,7 @@ def EmitC_ReturnOp : EmitC_Op<"return", [Pure, HasParent<"FuncOp">,
722722
}
723723
```
724724
}];
725-
let arguments = (ins Optional<AnyType>:$operand);
725+
let arguments = (ins Optional<EmitCType>:$operand);
726726

727727
let assemblyFormat = "attr-dict ($operand^ `:` type($operand))?";
728728
let hasVerifier = 1;
@@ -766,7 +766,7 @@ def EmitC_LiteralOp : EmitC_Op<"literal", [Pure]> {
766766
}];
767767

768768
let arguments = (ins StrAttr:$value);
769-
let results = (outs AnyType:$result);
769+
let results = (outs EmitCType:$result);
770770

771771
let hasVerifier = 1;
772772
let assemblyFormat = "$value attr-dict `:` type($result)";
@@ -932,8 +932,8 @@ def EmitC_ConditionalOp : EmitC_Op<"conditional",
932932
int32_t v6 = v3 ? v4 : v5;
933933
```
934934
}];
935-
let arguments = (ins I1:$condition, AnyType:$true_value, AnyType:$false_value);
936-
let results = (outs AnyType:$result);
935+
let arguments = (ins I1:$condition, EmitCType:$true_value, EmitCType:$false_value);
936+
let results = (outs EmitCType:$result);
937937
let assemblyFormat = "operands attr-dict `:` type($result)";
938938
}
939939

@@ -1009,7 +1009,7 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> {
10091009
}];
10101010

10111011
let arguments = (ins EmitC_OpaqueOrTypedAttr:$value);
1012-
let results = (outs AnyType);
1012+
let results = (outs EmitCType);
10131013

10141014
let hasVerifier = 1;
10151015
}
@@ -1068,7 +1068,7 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> {
10681068
```
10691069
}];
10701070

1071-
let arguments = (ins AnyType:$var, AnyType:$value);
1071+
let arguments = (ins EmitCType:$var, EmitCType:$value);
10721072
let results = (outs);
10731073

10741074
let hasVerifier = 1;
@@ -1089,7 +1089,7 @@ def EmitC_YieldOp : EmitC_Op<"yield",
10891089
value is yielded.
10901090
}];
10911091

1092-
let arguments = (ins Optional<AnyType>:$result);
1092+
let arguments = (ins Optional<EmitCType>:$result);
10931093
let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
10941094

10951095
let hasVerifier = 1;
@@ -1173,8 +1173,8 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript", []> {
11731173
EmitC_OpaqueType,
11741174
EmitC_PointerType]>,
11751175
"the value to subscript">:$value,
1176-
Variadic<AnyType>:$indices);
1177-
let results = (outs AnyType:$result);
1176+
Variadic<EmitCType>:$indices);
1177+
let results = (outs EmitCType:$result);
11781178

11791179
let builders = [
11801180
OpBuilder<(ins "TypedValue<ArrayType>":$array, "ValueRange":$indices), [{

mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ include "mlir/IR/BuiltinTypeInterfaces.td"
2222
// EmitC type definitions
2323
//===----------------------------------------------------------------------===//
2424

25+
def EmitCType : Type<CPred<"emitc::isSupportedEmitCType($_self)">,
26+
"type supported by EmitC">;
27+
2528
def EmitCIntegerType : Type<CPred<"emitc::isSupportedIntegerType($_self)">,
2629
"integer type supported by EmitC">;
2730

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,15 @@
1010
#include "mlir/Dialect/EmitC/IR/EmitCTraits.h"
1111
#include "mlir/IR/Builders.h"
1212
#include "mlir/IR/BuiltinAttributes.h"
13+
#include "mlir/IR/BuiltinTypes.h"
1314
#include "mlir/IR/DialectImplementation.h"
1415
#include "mlir/IR/IRMapping.h"
16+
#include "mlir/IR/Types.h"
1517
#include "mlir/Interfaces/FunctionImplementation.h"
18+
#include "llvm/ADT/STLExtras.h"
1619
#include "llvm/ADT/StringExtras.h"
1720
#include "llvm/ADT/TypeSwitch.h"
21+
#include "llvm/Support/Casting.h"
1822

1923
using namespace mlir;
2024
using namespace mlir::emitc;
@@ -54,6 +58,40 @@ void mlir::emitc::buildTerminatedBody(OpBuilder &builder, Location loc) {
5458
builder.create<emitc::YieldOp>(loc);
5559
}
5660

61+
bool mlir::emitc::isSupportedEmitCType(Type type) {
62+
if (llvm::isa<emitc::OpaqueType>(type))
63+
return true;
64+
if (auto ptrType = llvm::dyn_cast<emitc::PointerType>(type))
65+
return isSupportedEmitCType(ptrType.getPointee());
66+
if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(type)) {
67+
auto elemType = arrayType.getElementType();
68+
return !llvm::isa<emitc::ArrayType>(elemType) &&
69+
isSupportedEmitCType(elemType);
70+
}
71+
if (type.isIndex())
72+
return true;
73+
if (llvm::isa<IntegerType>(type))
74+
return isSupportedIntegerType(type);
75+
if (llvm::isa<FloatType>(type))
76+
return isSupportedFloatType(type);
77+
if (auto tensorType = llvm::dyn_cast<TensorType>(type)) {
78+
if (!tensorType.hasStaticShape()) {
79+
return false;
80+
}
81+
auto elemType = tensorType.getElementType();
82+
if (llvm::isa<emitc::ArrayType>(elemType)) {
83+
return false;
84+
}
85+
return isSupportedEmitCType(elemType);
86+
}
87+
if (auto tupleType = llvm::dyn_cast<TupleType>(type)) {
88+
return llvm::all_of(tupleType.getTypes(), [](Type type) {
89+
return !llvm::isa<emitc::ArrayType>(type) && isSupportedEmitCType(type);
90+
});
91+
}
92+
return false;
93+
}
94+
5795
bool mlir::emitc::isSupportedIntegerType(Type type) {
5896
if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
5997
switch (intType.getWidth()) {

mlir/test/Dialect/EmitC/invalid_types.mlir

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,51 @@ func.func @illegal_float_type(%arg0: f80, %arg1: f80) {
9797
%mul = "emitc.mul" (%arg0, %arg1) : (f80, f80) -> f80
9898
return
9999
}
100+
101+
// -----
102+
103+
func.func @illegal_pointee_type() {
104+
// expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got '!emitc.ptr<i11>'}}
105+
%v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> !emitc.ptr<i11>
106+
return
107+
}
108+
109+
// -----
110+
111+
func.func @illegal_non_static_tensor_shape_type() {
112+
// expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tensor<?xf32>'}}
113+
%v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tensor<?xf32>
114+
return
115+
}
116+
117+
// -----
118+
119+
func.func @illegal_tensor_array_element_type() {
120+
// expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tensor<!emitc.array<9xi16>>'}}
121+
%v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tensor<!emitc.array<9xi16>>
122+
return
123+
}
124+
125+
// -----
126+
127+
func.func @illegal_tensor_integer_element_type() {
128+
// expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tensor<9xi11>'}}
129+
%v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tensor<9xi11>
130+
return
131+
}
132+
133+
// -----
134+
135+
func.func @illegal_tuple_array_element_type() {
136+
// expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tuple<!emitc.array<9xf32>, f32>'}}
137+
%v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tuple<!emitc.array<9xf32>, f32>
138+
return
139+
}
140+
141+
// -----
142+
143+
func.func @illegal_tuple_float_element_type() {
144+
// expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tuple<i32, f80>'}}
145+
%v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tuple<i32, f80>
146+
return
147+
}

0 commit comments

Comments
 (0)