[mlir][emitc] Restrict types in EmitC #88391
Restrict the types that are valid for EmitC operations, using what the emitter currently supports as the restriction. Define a utility function for valid types so that it can be used to restrict the operations in TableGen and is also available for reuse in dialect conversions.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-emitc
Author: Tina Jung (TinaAMD)
Changes: Restrict the types that are valid for EmitC operations, using what the emitter currently supports as the restriction. Define a utility function for valid types so that it can be used to restrict the operations in TableGen and is also available for reuse in dialect conversions.
Full diff: https://github.com/llvm/llvm-project/pull/88391.diff
5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
index c03915667db653..5d9531cd124154 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
@@ -31,6 +31,9 @@ namespace mlir {
namespace emitc {
void buildTerminatedBody(OpBuilder &builder, Location loc);
+/// Determines whether \p type is valid in EmitC.
+bool isSupportedEmitCType(mlir::Type type);
+
/// Determines whether \p type is a valid integer type in EmitC.
bool isSupportedIntegerType(mlir::Type type);
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index e611fd2f0f15c4..c1a1e77c34ab25 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -34,16 +34,16 @@ class EmitC_Op<string mnemonic, list<Trait> traits = []>
// Base class for unary operations.
class EmitC_UnaryOp<string mnemonic, list<Trait> traits = []> :
EmitC_Op<mnemonic, traits> {
- let arguments = (ins AnyType);
- let results = (outs AnyType);
+ let arguments = (ins EmitCType);
+ let results = (outs EmitCType);
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
// Base class for binary operations.
class EmitC_BinaryOp<string mnemonic, list<Trait> traits = []> :
EmitC_Op<mnemonic, traits> {
- let arguments = (ins AnyType:$lhs, AnyType:$rhs);
- let results = (outs AnyType);
+ let arguments = (ins EmitCType:$lhs, EmitCType:$rhs);
+ let results = (outs EmitCType);
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
@@ -97,9 +97,9 @@ def EmitC_ApplyOp : EmitC_Op<"apply", [CExpression]> {
}];
let arguments = (ins
Arg<StrAttr, "the operator to apply">:$applicableOperator,
- AnyType:$operand
+ EmitCType:$operand
);
- let results = (outs AnyType:$result);
+ let results = (outs EmitCType:$result);
let assemblyFormat = [{
$applicableOperator `(` $operand `)` attr-dict `:` functional-type($operand, results)
}];
@@ -240,9 +240,9 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> {
Arg<StrAttr, "the C++ function to call">:$callee,
Arg<OptionalAttr<ArrayAttr>, "the order of operands and further attributes">:$args,
Arg<OptionalAttr<ArrayAttr>, "template arguments">:$template_args,
- Variadic<AnyType>:$operands
+ Variadic<EmitCType>:$operands
);
- let results = (outs Variadic<AnyType>);
+ let results = (outs Variadic<EmitCType>);
let builders = [
OpBuilder<(ins
"::mlir::TypeRange":$resultTypes,
@@ -284,8 +284,8 @@ def EmitC_CastOp : EmitC_Op<"cast",
```
}];
- let arguments = (ins AnyType:$source);
- let results = (outs AnyType:$dest);
+ let arguments = (ins EmitCType:$source);
+ let results = (outs EmitCType:$dest);
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
}
@@ -323,9 +323,9 @@ def EmitC_CmpOp : EmitC_BinaryOp<"cmp", [CExpression]> {
}];
let arguments = (ins EmitC_CmpPredicateAttr:$predicate,
- AnyType:$lhs,
- AnyType:$rhs);
- let results = (outs AnyType);
+ EmitCType:$lhs,
+ EmitCType:$rhs);
+ let results = (outs EmitCType);
let assemblyFormat = "$predicate `,` operands attr-dict `:` functional-type(operands, results)";
}
@@ -353,7 +353,7 @@ def EmitC_ConstantOp : EmitC_Op<"constant", [ConstantLike]> {
}];
let arguments = (ins EmitC_OpaqueOrTypedAttr:$value);
- let results = (outs AnyType);
+ let results = (outs EmitCType);
let hasFolder = 1;
let hasVerifier = 1;
@@ -423,7 +423,7 @@ def EmitC_ExpressionOp : EmitC_Op<"expression",
}];
let arguments = (ins UnitAttr:$do_not_inline);
- let results = (outs AnyType:$result);
+ let results = (outs EmitCType:$result);
let regions = (region SizedRegion<1>:$region);
let hasVerifier = 1;
@@ -531,8 +531,8 @@ def EmitC_CallOp : EmitC_Op<"call",
%2 = emitc.call @my_add(%0, %1) : (f32, f32) -> f32
```
}];
- let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<AnyType>:$operands);
- let results = (outs Variadic<AnyType>);
+ let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<EmitCType>:$operands);
+ let results = (outs Variadic<EmitCType>);
let builders = [
OpBuilder<(ins "FuncOp":$callee, CArg<"ValueRange", "{}">:$operands), [{
@@ -722,7 +722,7 @@ def EmitC_ReturnOp : EmitC_Op<"return", [Pure, HasParent<"FuncOp">,
}
```
}];
- let arguments = (ins Optional<AnyType>:$operand);
+ let arguments = (ins Optional<EmitCType>:$operand);
let assemblyFormat = "attr-dict ($operand^ `:` type($operand))?";
let hasVerifier = 1;
@@ -766,7 +766,7 @@ def EmitC_LiteralOp : EmitC_Op<"literal", [Pure]> {
}];
let arguments = (ins StrAttr:$value);
- let results = (outs AnyType:$result);
+ let results = (outs EmitCType:$result);
let hasVerifier = 1;
let assemblyFormat = "$value attr-dict `:` type($result)";
@@ -932,8 +932,8 @@ def EmitC_ConditionalOp : EmitC_Op<"conditional",
int32_t v6 = v3 ? v4 : v5;
```
}];
- let arguments = (ins I1:$condition, AnyType:$true_value, AnyType:$false_value);
- let results = (outs AnyType:$result);
+ let arguments = (ins I1:$condition, EmitCType:$true_value, EmitCType:$false_value);
+ let results = (outs EmitCType:$result);
let assemblyFormat = "operands attr-dict `:` type($result)";
}
@@ -1009,7 +1009,7 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> {
}];
let arguments = (ins EmitC_OpaqueOrTypedAttr:$value);
- let results = (outs AnyType);
+ let results = (outs EmitCType);
let hasVerifier = 1;
}
@@ -1068,7 +1068,7 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> {
```
}];
- let arguments = (ins AnyType:$var, AnyType:$value);
+ let arguments = (ins EmitCType:$var, EmitCType:$value);
let results = (outs);
let hasVerifier = 1;
@@ -1089,7 +1089,7 @@ def EmitC_YieldOp : EmitC_Op<"yield",
value is yielded.
}];
- let arguments = (ins Optional<AnyType>:$result);
+ let arguments = (ins Optional<EmitCType>:$result);
let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
let hasVerifier = 1;
@@ -1173,8 +1173,8 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript", []> {
EmitC_OpaqueType,
EmitC_PointerType]>,
"the value to subscript">:$value,
- Variadic<AnyType>:$indices);
- let results = (outs AnyType:$result);
+ Variadic<EmitCType>:$indices);
+ let results = (outs EmitCType:$result);
let builders = [
OpBuilder<(ins "TypedValue<ArrayType>":$array, "ValueRange":$indices), [{
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
index bce5807230ce49..444395b915e250 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
@@ -22,6 +22,9 @@ include "mlir/IR/BuiltinTypeInterfaces.td"
// EmitC type definitions
//===----------------------------------------------------------------------===//
+def EmitCType : Type<CPred<"emitc::isSupportedEmitCType($_self)">,
+ "type supported by EmitC">;
+
def EmitCIntegerType : Type<CPred<"emitc::isSupportedIntegerType($_self)">,
"integer type supported by EmitC">;
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 7cbf28b627342a..b037ef3c0b4152 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -10,11 +10,15 @@
#include "mlir/Dialect/EmitC/IR/EmitCTraits.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Types.h"
#include "mlir/Interfaces/FunctionImplementation.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
using namespace mlir;
using namespace mlir::emitc;
@@ -54,6 +58,40 @@ void mlir::emitc::buildTerminatedBody(OpBuilder &builder, Location loc) {
builder.create<emitc::YieldOp>(loc);
}
+bool mlir::emitc::isSupportedEmitCType(Type type) {
+ if (llvm::isa<emitc::OpaqueType>(type))
+ return true;
+ if (auto ptrType = llvm::dyn_cast<emitc::PointerType>(type))
+ return isSupportedEmitCType(ptrType.getPointee());
+ if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(type)) {
+ auto elemType = arrayType.getElementType();
+ return !llvm::isa<emitc::ArrayType>(elemType) &&
+ isSupportedEmitCType(elemType);
+ }
+ if (type.isIndex())
+ return true;
+ if (llvm::isa<IntegerType>(type))
+ return isSupportedIntegerType(type);
+ if (llvm::isa<FloatType>(type))
+ return isSupportedFloatType(type);
+ if (auto tensorType = llvm::dyn_cast<TensorType>(type)) {
+ if (!tensorType.hasStaticShape()) {
+ return false;
+ }
+ auto elemType = tensorType.getElementType();
+ if (llvm::isa<emitc::ArrayType>(elemType)) {
+ return false;
+ }
+ return isSupportedEmitCType(elemType);
+ }
+ if (auto tupleType = llvm::dyn_cast<TupleType>(type)) {
+ return llvm::all_of(tupleType.getTypes(), [](Type type) {
+ return !llvm::isa<emitc::ArrayType>(type) && isSupportedEmitCType(type);
+ });
+ }
+ return false;
+}
+
bool mlir::emitc::isSupportedIntegerType(Type type) {
if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
switch (intType.getWidth()) {
diff --git a/mlir/test/Dialect/EmitC/invalid_types.mlir b/mlir/test/Dialect/EmitC/invalid_types.mlir
index f9d517bf689b95..0ad8d4eabe6b8b 100644
--- a/mlir/test/Dialect/EmitC/invalid_types.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_types.mlir
@@ -97,3 +97,51 @@ func.func @illegal_float_type(%arg0: f80, %arg1: f80) {
%mul = "emitc.mul" (%arg0, %arg1) : (f80, f80) -> f80
return
}
+
+// -----
+
+func.func @illegal_pointee_type() {
+ // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got '!emitc.ptr<i11>'}}
+ %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> !emitc.ptr<i11>
+ return
+}
+
+// -----
+
+func.func @illegal_non_static_tensor_shape_type() {
+ // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tensor<?xf32>'}}
+ %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tensor<?xf32>
+ return
+}
+
+// -----
+
+func.func @illegal_tensor_array_element_type() {
+ // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tensor<!emitc.array<9xi16>>'}}
+ %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tensor<!emitc.array<9xi16>>
+ return
+}
+
+// -----
+
+func.func @illegal_tensor_integer_element_type() {
+ // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tensor<9xi11>'}}
+ %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tensor<9xi11>
+ return
+}
+
+// -----
+
+func.func @illegal_tuple_array_element_type() {
+ // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tuple<!emitc.array<9xf32>, f32>'}}
+ %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tuple<!emitc.array<9xf32>, f32>
+ return
+}
+
+// -----
+
+func.func @illegal_tuple_float_element_type() {
+ // expected-error @+1 {{'emitc.variable' op result #0 must be type supported by EmitC, but got 'tuple<i32, f80>'}}
+ %v = "emitc.variable"(){value = #emitc.opaque<"">} : () -> tuple<i32, f80>
+ return
+}
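The recursive rules that `isSupportedEmitCType` applies in the diff above can be illustrated with a small standalone model. This is only a sketch: `Ty` and the `*Ty` helper functions are hypothetical stand-ins rather than MLIR classes, and the accepted widths (1, 8, 16, 32, 64 bits for integers; 32 and 64 bits for floats) are an assumption, since the bodies of `isSupportedIntegerType` and `isSupportedFloatType` are not shown in this diff.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Toy stand-in for mlir::Type. This is NOT the MLIR API, only a model of it.
struct Ty {
  enum Kind { Opaque, Pointer, Array, Index, Integer, Float, Tensor, Tuple, Other };
  Kind kind;
  unsigned width;        // bit width, meaningful for Integer and Float only
  bool staticShape;      // meaningful for Tensor only
  std::vector<Ty> elems; // pointee, element type, or tuple members
};

// Hypothetical helper constructors for the toy model.
Ty intTy(unsigned w) { return {Ty::Integer, w, true, {}}; }
Ty floatTy(unsigned w) { return {Ty::Float, w, true, {}}; }
Ty ptrTy(const Ty &pointee) { return {Ty::Pointer, 0, true, {pointee}}; }
Ty arrayTy(const Ty &elem) { return {Ty::Array, 0, true, {elem}}; }
Ty tensorTy(const Ty &elem, bool isStatic) { return {Ty::Tensor, 0, isStatic, {elem}}; }
Ty tupleTy(const std::vector<Ty> &members) { return {Ty::Tuple, 0, true, members}; }

// Assumed set of supported integer widths (not shown in the diff).
bool isSupportedInteger(const Ty &t) {
  if (t.kind != Ty::Integer)
    return false;
  switch (t.width) {
  case 1: case 8: case 16: case 32: case 64:
    return true;
  default:
    return false;
  }
}

// Assumed set of supported float widths (not shown in the diff).
bool isSupportedFloat(const Ty &t) {
  return t.kind == Ty::Float && (t.width == 32 || t.width == 64);
}

// Mirrors the structure of isSupportedEmitCType from the diff.
bool isSupported(const Ty &t) {
  switch (t.kind) {
  case Ty::Opaque:
  case Ty::Index:
    return true;
  case Ty::Pointer:
    return isSupported(t.elems[0]);
  case Ty::Array: // arrays of arrays are rejected
    return t.elems[0].kind != Ty::Array && isSupported(t.elems[0]);
  case Ty::Integer:
    return isSupportedInteger(t);
  case Ty::Float:
    return isSupportedFloat(t);
  case Ty::Tensor: // needs a static shape and a supported non-array element type
    return t.staticShape && t.elems[0].kind != Ty::Array &&
           isSupported(t.elems[0]);
  case Ty::Tuple: // every member must be a supported non-array type
    return std::all_of(t.elems.begin(), t.elems.end(), [](const Ty &m) {
      return m.kind != Ty::Array && isSupported(m);
    });
  case Ty::Other:
    return false;
  }
  return false;
}
```

In this model, `isSupported(ptrTy(intTy(11)))` is false, matching the `!emitc.ptr<i11>` case in `invalid_types.mlir`, while `isSupported(ptrTy(intTy(32)))` is true.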
That is a nice improvement, thanks!
@simon-camp do you have time to review?
It is perfectly fine for us if either @simon-camp or I review, so feel free to go ahead and merge after an approval from either Simon or me. There is no need to wait for a second approval if you feel confident with this.
Thanks for letting me know, I'll go with that policy then 🙂
I aligned the restriction to the capabilities of the Emitter, such that one could only build EmitC code which the emitter can handle (to be able to find errors upon construction already). As far as I can see, the emitter doesn't support |
mlir-aie does not use the Emitter directly, but it does have a similar pass supporting vector and bf16 types, and that pass reuses ops from the emitc dialect. Because those ops are mixed with "unsupported" types, the emitc verifier fails. To me it seems like the legality check for the Emitter should be in the Emitter, not in the dialect. Then the input to the Emitter can be generated by a multi-stage lowering, where the intermediate IR(s) might contain things not supported by the Emitter, but those things will be legalized away by subsequent transformations.
I prefer the current design because the emitc dialect is meant to represent exactly what can be emitted into C. FYI @marbre
We have had these discussions several times online and offline (already during upstreaming the dialect and the emitter), and the consensus was to make the dialect robust and keep the emitter simple. As @mgehre-amd mentioned, types that are not supported by the emitter can be represented with |
Thanks all for answering my questions and providing the possible workaround.
Just to clarify, does this mean that the right way to have a call to a function with unsupported types is to first Edit: I've tried the casting approach, and the only thing that seems to be working for me is injecting |
emitc is about emitting C code, so the first question becomes: How should a |
Assuming you would expect a bfloat16 represented as follows in your C code:
When you convert from some dialect that uses bfloat16 to EmitC, you can add a type converter that describes what your type will look like in EmitC.
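The type-converter idea can be sketched without any MLIR dependency. The following is a toy model under stated assumptions, not the snippet from this comment: `ToyTypeConverter` and `makeBf16Converter` are hypothetical names, types are represented as plain strings, and the C type name `bfloat16` inside the opaque type is made up for illustration. In MLIR itself the corresponding pieces would presumably be `TypeConverter::addConversion` and `emitc::OpaqueType`.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for a dialect-conversion type converter. Types are plain
// strings here; this is NOT the MLIR TypeConverter API.
class ToyTypeConverter {
public:
  // A rule writes the converted type to `out` and returns true if it applies.
  using Rule = std::function<bool(const std::string &in, std::string &out)>;

  void addConversion(Rule rule) { rules.push_back(std::move(rule)); }

  // Try rules newest-first, so later, more specific rules take precedence.
  bool convert(const std::string &in, std::string &out) const {
    for (auto it = rules.rbegin(); it != rules.rend(); ++it)
      if ((*it)(in, out))
        return true;
    return false; // no rule applies: the type cannot be legalized
  }

private:
  std::vector<Rule> rules;
};

// Build a converter that keeps f32 and i32 unchanged (a tiny stand-in for
// "types EmitC already supports") and maps bf16 to an opaque EmitC type.
// The C type name "bfloat16" is an assumption made for this sketch.
ToyTypeConverter makeBf16Converter() {
  ToyTypeConverter converter;
  converter.addConversion([](const std::string &in, std::string &out) -> bool {
    if (in != "f32" && in != "i32")
      return false;
    out = in; // already representable, keep as-is
    return true;
  });
  converter.addConversion([](const std::string &in, std::string &out) -> bool {
    if (in != "bf16")
      return false;
    out = "!emitc.opaque<\"bfloat16\">";
    return true;
  });
  return converter;
}
```

With this converter, `bf16` becomes `!emitc.opaque<"bfloat16">`, `f32` stays `f32`, and a type with no matching rule such as `f80` fails to convert, which is the point at which a real conversion would report an error instead of producing IR that the verifier rejects.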
I see. I did understand, then. I believe we're indeed abusing the dialect: because we don't convert everything to emitc before translating, we end up with a mix of dialects at translation time, which leaves unreconciled unrealized casts. I will work around this the best I can, thank you both for your clarifications!