Skip to content

Commit 818af71

Browse files
authored
[mlir][emitc] Add ArrayType (#83386)
This models a one or multi-dimensional C/C++ array. The type implements the `ShapedTypeInterface` and prints similar to memref/tensor: ``` %arg0: !emitc.array<1xf32>, %arg1: !emitc.array<10x20x30xi32>, %arg2: !emitc.array<30x!emitc.ptr<i32>>, %arg3: !emitc.array<30x!emitc.opaque<"int">> ``` It can be translated to a C array type when used as function parameter or as `emitc.variable` type.
1 parent d99bb01 commit 818af71

File tree

12 files changed

+346
-8
lines changed

12 files changed

+346
-8
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,64 @@
1616

1717
include "mlir/IR/AttrTypeBase.td"
1818
include "mlir/Dialect/EmitC/IR/EmitCBase.td"
19+
include "mlir/IR/BuiltinTypeInterfaces.td"
1920

2021
//===----------------------------------------------------------------------===//
2122
// EmitC type definitions
2223
//===----------------------------------------------------------------------===//
2324

24-
class EmitC_Type<string name, string typeMnemonic>
25-
: TypeDef<EmitC_Dialect, name> {
25+
class EmitC_Type<string name, string typeMnemonic, list<Trait> traits = []>
26+
: TypeDef<EmitC_Dialect, name, traits> {
2627
let mnemonic = typeMnemonic;
2728
}
2829

30+
def EmitC_ArrayType : EmitC_Type<"Array", "array", [ShapedTypeInterface]> {
31+
let summary = "EmitC array type";
32+
33+
let description = [{
34+
An array data type.
35+
36+
Example:
37+
38+
```mlir
39+
// Array emitted as `int32_t[10]`
40+
!emitc.array<10xi32>
41+
// Array emitted as `float[10][20]`
42+
!emitc.ptr<10x20xf32>
43+
```
44+
}];
45+
46+
let parameters = (ins
47+
ArrayRefParameter<"int64_t">:$shape,
48+
"Type":$elementType
49+
);
50+
51+
let builders = [
52+
TypeBuilderWithInferredContext<(ins
53+
"ArrayRef<int64_t>":$shape,
54+
"Type":$elementType
55+
), [{
56+
return $_get(elementType.getContext(), shape, elementType);
57+
}]>
58+
];
59+
let extraClassDeclaration = [{
60+
/// Returns if this type is ranked (always true).
61+
bool hasRank() const { return true; }
62+
63+
/// Clone this array type with the given shape and element type. If the
64+
/// provided shape is `std::nullopt`, the current shape of the type is used.
65+
ArrayType cloneWith(std::optional<ArrayRef<int64_t>> shape,
66+
Type elementType) const;
67+
68+
static bool isValidElementType(Type type) {
69+
return type.isIntOrIndexOrFloat() ||
70+
llvm::isa<PointerType, OpaqueType>(type);
71+
}
72+
}];
73+
let genVerifyDecl = 1;
74+
let hasCustomAssemblyFormat = 1;
75+
}
76+
2977
def EmitC_OpaqueType : EmitC_Type<"Opaque", "opaque"> {
3078
let summary = "EmitC opaque type";
3179

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,8 @@ LogicalResult emitc::AssignOp::verify() {
141141
return emitOpError() << "requires value's type (" << value.getType()
142142
<< ") to match variable's type (" << variable.getType()
143143
<< ")";
144+
if (isa<ArrayType>(variable.getType()))
145+
return emitOpError() << "cannot assign to array type";
144146
return success();
145147
}
146148

@@ -192,6 +194,11 @@ LogicalResult emitc::CallOpaqueOp::verify() {
192194
}
193195
}
194196

197+
if (llvm::any_of(getResultTypes(),
198+
[](Type type) { return isa<ArrayType>(type); })) {
199+
return emitOpError() << "cannot return array type";
200+
}
201+
195202
return success();
196203
}
197204

@@ -456,6 +463,9 @@ LogicalResult FuncOp::verify() {
456463
return emitOpError("requires zero or exactly one result, but has ")
457464
<< getNumResults();
458465

466+
if (getNumResults() == 1 && isa<ArrayType>(getResultTypes()[0]))
467+
return emitOpError("cannot return array type");
468+
459469
return success();
460470
}
461471

@@ -763,6 +773,69 @@ LogicalResult emitc::YieldOp::verify() {
763773
#define GET_TYPEDEF_CLASSES
764774
#include "mlir/Dialect/EmitC/IR/EmitCTypes.cpp.inc"
765775

776+
//===----------------------------------------------------------------------===//
777+
// ArrayType
778+
//===----------------------------------------------------------------------===//
779+
780+
Type emitc::ArrayType::parse(AsmParser &parser) {
781+
if (parser.parseLess())
782+
return Type();
783+
784+
SmallVector<int64_t, 4> dimensions;
785+
if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false,
786+
/*withTrailingX=*/true))
787+
return Type();
788+
// Parse the element type.
789+
auto typeLoc = parser.getCurrentLocation();
790+
Type elementType;
791+
if (parser.parseType(elementType))
792+
return Type();
793+
794+
// Check that array is formed from allowed types.
795+
if (!isValidElementType(elementType))
796+
return parser.emitError(typeLoc, "invalid array element type"), Type();
797+
if (parser.parseGreater())
798+
return Type();
799+
return parser.getChecked<ArrayType>(dimensions, elementType);
800+
}
801+
802+
void emitc::ArrayType::print(AsmPrinter &printer) const {
803+
printer << "<";
804+
for (int64_t dim : getShape()) {
805+
printer << dim << 'x';
806+
}
807+
printer.printType(getElementType());
808+
printer << ">";
809+
}
810+
811+
LogicalResult emitc::ArrayType::verify(
812+
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
813+
::llvm::ArrayRef<int64_t> shape, Type elementType) {
814+
if (shape.empty())
815+
return emitError() << "shape must not be empty";
816+
817+
for (int64_t dim : shape) {
818+
if (dim <= 0)
819+
return emitError() << "dimensions must have positive size";
820+
}
821+
822+
if (!elementType)
823+
return emitError() << "element type must not be none";
824+
825+
if (!isValidElementType(elementType))
826+
return emitError() << "invalid array element type";
827+
828+
return success();
829+
}
830+
831+
emitc::ArrayType
832+
emitc::ArrayType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
833+
Type elementType) const {
834+
if (!shape)
835+
return emitc::ArrayType::get(getShape(), elementType);
836+
return emitc::ArrayType::get(*shape, elementType);
837+
}
838+
766839
//===----------------------------------------------------------------------===//
767840
// OpaqueType
768841
//===----------------------------------------------------------------------===//

mlir/lib/Target/Cpp/TranslateToCpp.cpp

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,10 @@ struct CppEmitter {
139139
LogicalResult emitVariableDeclaration(OpResult result,
140140
bool trailingSemicolon);
141141

142+
/// Emits a declaration of a variable with the given type and name.
143+
LogicalResult emitVariableDeclaration(Location loc, Type type,
144+
StringRef name);
145+
142146
/// Emits the variable declaration and assignment prefix for 'op'.
143147
/// - emits separate variable followed by std::tie for multi-valued operation;
144148
/// - emits single type followed by variable for single result;
@@ -870,10 +874,8 @@ static LogicalResult printFunctionArgs(CppEmitter &emitter,
870874

871875
return (interleaveCommaWithError(
872876
arguments, os, [&](BlockArgument arg) -> LogicalResult {
873-
if (failed(emitter.emitType(functionOp->getLoc(), arg.getType())))
874-
return failure();
875-
os << " " << emitter.getOrCreateName(arg);
876-
return success();
877+
return emitter.emitVariableDeclaration(
878+
functionOp->getLoc(), arg.getType(), emitter.getOrCreateName(arg));
877879
}));
878880
}
879881

@@ -917,6 +919,9 @@ static LogicalResult printFunctionBody(CppEmitter &emitter,
917919
if (emitter.hasValueInScope(arg))
918920
return functionOp->emitOpError(" block argument #")
919921
<< arg.getArgNumber() << " is out of scope";
922+
if (isa<ArrayType>(arg.getType()))
923+
return functionOp->emitOpError("cannot emit block argument #")
924+
<< arg.getArgNumber() << " with array type";
920925
if (failed(
921926
emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) {
922927
return failure();
@@ -960,6 +965,11 @@ static LogicalResult printOperation(CppEmitter &emitter,
960965
"with multiple blocks needs variables declared at top");
961966
}
962967

968+
if (llvm::any_of(functionOp.getResultTypes(),
969+
[](Type type) { return isa<ArrayType>(type); })) {
970+
return functionOp.emitOpError() << "cannot emit array type as result type";
971+
}
972+
963973
CppEmitter::Scope scope(emitter);
964974
raw_indented_ostream &os = emitter.ostream();
965975
if (failed(emitter.emitTypes(functionOp.getLoc(),
@@ -1306,9 +1316,10 @@ LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
13061316
return result.getDefiningOp()->emitError(
13071317
"result variable for the operation already declared");
13081318
}
1309-
if (failed(emitType(result.getOwner()->getLoc(), result.getType())))
1319+
if (failed(emitVariableDeclaration(result.getOwner()->getLoc(),
1320+
result.getType(),
1321+
getOrCreateName(result))))
13101322
return failure();
1311-
os << " " << getOrCreateName(result);
13121323
if (trailingSemicolon)
13131324
os << ";\n";
13141325
return success();
@@ -1403,6 +1414,23 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
14031414
return success();
14041415
}
14051416

1417+
LogicalResult CppEmitter::emitVariableDeclaration(Location loc, Type type,
1418+
StringRef name) {
1419+
if (auto arrType = dyn_cast<emitc::ArrayType>(type)) {
1420+
if (failed(emitType(loc, arrType.getElementType())))
1421+
return failure();
1422+
os << " " << name;
1423+
for (auto dim : arrType.getShape()) {
1424+
os << "[" << dim << "]";
1425+
}
1426+
return success();
1427+
}
1428+
if (failed(emitType(loc, type)))
1429+
return failure();
1430+
os << " " << name;
1431+
return success();
1432+
}
1433+
14061434
LogicalResult CppEmitter::emitType(Location loc, Type type) {
14071435
if (auto iType = dyn_cast<IntegerType>(type)) {
14081436
switch (iType.getWidth()) {
@@ -1438,6 +1466,8 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) {
14381466
if (!tType.hasStaticShape())
14391467
return emitError(loc, "cannot emit tensor type with non static shape");
14401468
os << "Tensor<";
1469+
if (isa<ArrayType>(tType.getElementType()))
1470+
return emitError(loc, "cannot emit tensor of array type ") << type;
14411471
if (failed(emitType(loc, tType.getElementType())))
14421472
return failure();
14431473
auto shape = tType.getShape();
@@ -1454,7 +1484,16 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) {
14541484
os << oType.getValue();
14551485
return success();
14561486
}
1487+
if (auto aType = dyn_cast<emitc::ArrayType>(type)) {
1488+
if (failed(emitType(loc, aType.getElementType())))
1489+
return failure();
1490+
for (auto dim : aType.getShape())
1491+
os << "[" << dim << "]";
1492+
return success();
1493+
}
14571494
if (auto pType = dyn_cast<emitc::PointerType>(type)) {
1495+
if (isa<ArrayType>(pType.getPointee()))
1496+
return emitError(loc, "cannot emit pointer to array type ") << type;
14581497
if (failed(emitType(loc, pType.getPointee())))
14591498
return failure();
14601499
os << "*";
@@ -1476,6 +1515,9 @@ LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) {
14761515
}
14771516

14781517
LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
1518+
if (llvm::any_of(types, [](Type type) { return isa<ArrayType>(type); })) {
1519+
return emitError(loc, "cannot emit tuple of array type");
1520+
}
14791521
os << "std::tuple<";
14801522
if (failed(interleaveCommaWithError(
14811523
types, os, [&](Type type) { return emitType(loc, type); })))

mlir/test/Dialect/EmitC/invalid_ops.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,14 @@ func.func @dense_template_argument(%arg : i32) {
8080

8181
// -----
8282

83+
func.func @array_result() {
84+
// expected-error @+1 {{'emitc.call_opaque' op cannot return array type}}
85+
emitc.call_opaque "array_result"() : () -> !emitc.array<4xi32>
86+
return
87+
}
88+
89+
// -----
90+
8391
func.func @empty_operator(%arg : i32) {
8492
// expected-error @+1 {{'emitc.apply' op applicable operator must not be empty}}
8593
%2 = emitc.apply ""(%arg) : (i32) -> !emitc.ptr<i32>
@@ -129,6 +137,14 @@ func.func @cast_tensor(%arg : tensor<f32>) {
129137

130138
// -----
131139

140+
func.func @cast_array(%arg : !emitc.array<4xf32>) {
141+
// expected-error @+1 {{'emitc.cast' op operand type '!emitc.array<4xf32>' and result type '!emitc.array<4xf32>' are cast incompatible}}
142+
%1 = emitc.cast %arg: !emitc.array<4xf32> to !emitc.array<4xf32>
143+
return
144+
}
145+
146+
// -----
147+
132148
func.func @add_two_pointers(%arg0: !emitc.ptr<f32>, %arg1: !emitc.ptr<f32>) {
133149
// expected-error @+1 {{'emitc.add' op requires that at most one operand is a pointer}}
134150
%1 = "emitc.add" (%arg0, %arg1) : (!emitc.ptr<f32>, !emitc.ptr<f32>) -> !emitc.ptr<f32>
@@ -235,6 +251,15 @@ func.func @test_assign_type_mismatch(%arg1: f32) {
235251

236252
// -----
237253

254+
func.func @test_assign_to_array(%arg1: !emitc.array<4xi32>) {
255+
%v = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4xi32>
256+
// expected-error @+1 {{'emitc.assign' op cannot assign to array type}}
257+
emitc.assign %arg1 : !emitc.array<4xi32> to %v : !emitc.array<4xi32>
258+
return
259+
}
260+
261+
// -----
262+
238263
func.func @test_expression_no_yield() -> i32 {
239264
// expected-error @+1 {{'emitc.expression' op must yield a value at termination}}
240265
%r = emitc.expression : i32 {
@@ -313,6 +338,13 @@ emitc.func @return_type_mismatch() -> i32 {
313338

314339
// -----
315340

341+
// expected-error@+1 {{'emitc.func' op cannot return array type}}
342+
emitc.func @return_type_array(%arg : !emitc.array<4xi32>) -> !emitc.array<4xi32> {
343+
emitc.return %arg : !emitc.array<4xi32>
344+
}
345+
346+
// -----
347+
316348
func.func @return_inside_func.func(%0: i32) -> (i32) {
317349
// expected-error@+1 {{'emitc.return' op expects parent op 'emitc.func'}}
318350
emitc.return %0 : i32

0 commit comments

Comments
 (0)