Skip to content

Commit d543c27

Browse files
committed
[mlir][emitc] Add ArrayType
This models a one or multi-dimensional C/C++ array. The type implements the ShapedTypeInterface and prints similar to memref/tensor: ``` %arg0: !emitc.array<1xf32>, %arg1: !emitc.array<10x20x30xi32>, %arg2: !emitc.array<30x!emitc.ptr<i32>>, %arg3: !emitc.array<30x!emitc.opaque<"int">> ``` It can be translated to C++ when used as function parameter or as emitc.variable type.
1 parent c89d511 commit d543c27

File tree

7 files changed

+215
-8
lines changed

7 files changed

+215
-8
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,64 @@
1616

1717
include "mlir/IR/AttrTypeBase.td"
1818
include "mlir/Dialect/EmitC/IR/EmitCBase.td"
19+
include "mlir/IR/BuiltinTypeInterfaces.td"
1920

2021
//===----------------------------------------------------------------------===//
2122
// EmitC type definitions
2223
//===----------------------------------------------------------------------===//
2324

24-
class EmitC_Type<string name, string typeMnemonic>
25-
: TypeDef<EmitC_Dialect, name> {
25+
class EmitC_Type<string name, string typeMnemonic, list<Trait> traits = []>
26+
: TypeDef<EmitC_Dialect, name, traits> {
2627
let mnemonic = typeMnemonic;
2728
}
2829

30+
def EmitC_ArrayType : EmitC_Type<"Array", "array", [ShapedTypeInterface]> {
31+
let summary = "EmitC array type";
32+
33+
let description = [{
34+
An array data type.
35+
36+
Example:
37+
38+
```mlir
39+
// Array emitted as `int32_t[10]`
40+
!emitc.array<10xi32>
41+
// Array emitted as `float[10][20]`
42+
!emitc.ptr<10x20xf32>
43+
```
44+
}];
45+
46+
let parameters = (ins
47+
ArrayRefParameter<"int64_t">:$shape,
48+
"Type":$elementType
49+
);
50+
51+
let builders = [
52+
TypeBuilderWithInferredContext<(ins
53+
"ArrayRef<int64_t>":$shape,
54+
"Type":$elementType
55+
), [{
56+
return $_get(elementType.getContext(), shape, elementType);
57+
}]>
58+
];
59+
let extraClassDeclaration = [{
60+
/// Returns if this type is ranked (always true).
61+
bool hasRank() const { return true; }
62+
63+
/// Clone this vector type with the given shape and element type. If the
64+
/// provided shape is `std::nullopt`, the current shape of the type is used.
65+
ArrayType cloneWith(std::optional<ArrayRef<int64_t>> shape,
66+
Type elementType) const;
67+
68+
static bool isValidElementType(Type type) {
69+
return type.isIntOrIndexOrFloat() ||
70+
llvm::isa<PointerType, OpaqueType>(type);
71+
}
72+
}];
73+
let genVerifyDecl = 1;
74+
let hasCustomAssemblyFormat = 1;
75+
}
76+
2977
def EmitC_OpaqueType : EmitC_Type<"Opaque", "opaque"> {
3078
let summary = "EmitC opaque type";
3179

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -762,6 +762,69 @@ LogicalResult emitc::YieldOp::verify() {
762762
#define GET_TYPEDEF_CLASSES
763763
#include "mlir/Dialect/EmitC/IR/EmitCTypes.cpp.inc"
764764

765+
//===----------------------------------------------------------------------===//
766+
// ArrayType
767+
//===----------------------------------------------------------------------===//
768+
769+
Type emitc::ArrayType::parse(AsmParser &parser) {
770+
if (parser.parseLess())
771+
return Type();
772+
773+
SmallVector<int64_t, 4> dimensions;
774+
if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false,
775+
/*withTrailingX=*/true))
776+
return Type();
777+
// Parse the element type.
778+
auto typeLoc = parser.getCurrentLocation();
779+
Type elementType;
780+
if (parser.parseType(elementType))
781+
return Type();
782+
783+
// Check that memref is formed from allowed types.
784+
if (!isValidElementType(elementType))
785+
return parser.emitError(typeLoc, "invalid array element type"), Type();
786+
if (parser.parseGreater())
787+
return Type();
788+
return parser.getChecked<ArrayType>(dimensions, elementType);
789+
}
790+
791+
void emitc::ArrayType::print(AsmPrinter &printer) const {
792+
printer << "<";
793+
for (int64_t dim : getShape()) {
794+
printer << dim << 'x';
795+
}
796+
printer.printType(getElementType());
797+
printer << ">";
798+
}
799+
800+
LogicalResult emitc::ArrayType::verify(
801+
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
802+
::llvm::ArrayRef<int64_t> shape, Type elementType) {
803+
if (shape.empty())
804+
return emitError() << "shape must not be empty";
805+
806+
for (auto d : shape) {
807+
if (d <= 0)
808+
return emitError() << "dimensions must have positive size";
809+
}
810+
811+
if (!elementType)
812+
return emitError() << "element type must not be none";
813+
814+
if (!isValidElementType(elementType))
815+
return emitError() << "invalid array element type";
816+
817+
return success();
818+
}
819+
820+
emitc::ArrayType
821+
emitc::ArrayType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
822+
Type elementType) const {
823+
if (!shape)
824+
return emitc::ArrayType::get(getShape(), elementType);
825+
return emitc::ArrayType::get(*shape, elementType);
826+
}
827+
765828
//===----------------------------------------------------------------------===//
766829
// OpaqueType
767830
//===----------------------------------------------------------------------===//

mlir/lib/Target/Cpp/TranslateToCpp.cpp

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,10 @@ struct CppEmitter {
128128
LogicalResult emitVariableDeclaration(OpResult result,
129129
bool trailingSemicolon);
130130

131+
/// Emits a declaration of a variable with the given type and name.
132+
LogicalResult emitVariableDeclaration(Location loc, Type type,
133+
StringRef name);
134+
131135
/// Emits the variable declaration and assignment prefix for 'op'.
132136
/// - emits separate variable followed by std::tie for multi-valued operation;
133137
/// - emits single type followed by variable for single result;
@@ -783,10 +787,8 @@ static LogicalResult printFunctionArgs(CppEmitter &emitter,
783787

784788
return (interleaveCommaWithError(
785789
arguments, os, [&](BlockArgument arg) -> LogicalResult {
786-
if (failed(emitter.emitType(functionOp->getLoc(), arg.getType())))
787-
return failure();
788-
os << " " << emitter.getOrCreateName(arg);
789-
return success();
790+
return emitter.emitVariableDeclaration(
791+
functionOp->getLoc(), arg.getType(), emitter.getOrCreateName(arg));
790792
}));
791793
}
792794

@@ -1219,9 +1221,10 @@ LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
12191221
return result.getDefiningOp()->emitError(
12201222
"result variable for the operation already declared");
12211223
}
1222-
if (failed(emitType(result.getOwner()->getLoc(), result.getType())))
1224+
if (failed(emitVariableDeclaration(result.getOwner()->getLoc(),
1225+
result.getType(),
1226+
getOrCreateName(result))))
12231227
return failure();
1224-
os << " " << getOrCreateName(result);
12251228
if (trailingSemicolon)
12261229
os << ";\n";
12271230
return success();
@@ -1314,6 +1317,23 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
13141317
return success();
13151318
}
13161319

1320+
LogicalResult CppEmitter::emitVariableDeclaration(Location loc, Type type,
1321+
StringRef name) {
1322+
if (auto arrType = dyn_cast<emitc::ArrayType>(type)) {
1323+
if (failed(emitType(loc, arrType.getElementType())))
1324+
return failure();
1325+
os << " " << name;
1326+
for (auto dim : arrType.getShape()) {
1327+
os << "[" << dim << "]";
1328+
}
1329+
return success();
1330+
}
1331+
if (failed(emitType(loc, type)))
1332+
return failure();
1333+
os << " " << name;
1334+
return success();
1335+
}
1336+
13171337
LogicalResult CppEmitter::emitType(Location loc, Type type) {
13181338
if (auto iType = dyn_cast<IntegerType>(type)) {
13191339
switch (iType.getWidth()) {

mlir/test/Dialect/EmitC/invalid_types.mlir

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,57 @@ func.func @illegal_opaque_type_2() {
1111
// expected-error @+1 {{pointer not allowed as outer type with !emitc.opaque, use !emitc.ptr instead}}
1212
%1 = "emitc.variable"(){value = "nullptr" : !emitc.opaque<"int32_t*">} : () -> !emitc.opaque<"int32_t*">
1313
}
14+
15+
// -----
16+
17+
func.func @illegal_array_missing_spec(
18+
// expected-error @+1 {{expected non-function type}}
19+
%arg0: !emitc.array<>) {
20+
}
21+
22+
// -----
23+
24+
func.func @illegal_array_missing_shape(
25+
// expected-error @+1 {{shape must not be empty}}
26+
%arg9: !emitc.array<i32>) {
27+
}
28+
29+
// -----
30+
31+
func.func @illegal_array_missing_x(
32+
// expected-error @+1 {{expected 'x' in dimension list}}
33+
%arg0: !emitc.array<10>
34+
) {
35+
}
36+
37+
// -----
38+
39+
func.func @illegal_array_non_positive_dimenson(
40+
// expected-error @+1 {{dimensions must have positive size}}
41+
%arg0: !emitc.array<0xi32>
42+
) {
43+
}
44+
45+
// -----
46+
47+
func.func @illegal_array_missing_type(
48+
// expected-error @+1 {{expected non-function type}}
49+
%arg0: !emitc.array<10x>
50+
) {
51+
}
52+
53+
// -----
54+
55+
func.func @illegal_array_dynamic_shape(
56+
// expected-error @+1 {{expected static shape}}
57+
%arg0: !emitc.array<10x?xi32>
58+
) {
59+
}
60+
61+
// -----
62+
63+
func.func @illegal_array_unranked(
64+
// expected-error @+1 {{expected non-function type}}
65+
%arg0: !emitc.array<*xi32>
66+
) {
67+
}

mlir/test/Dialect/EmitC/types.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,17 @@ func.func @pointer_types() {
3939

4040
return
4141
}
42+
43+
// CHECK-LABEL: func @array_types(
44+
func.func @array_types(
45+
// CHECK-SAME: !emitc.array<1xf32>,
46+
%arg0: !emitc.array<1xf32>,
47+
// CHECK-SAME: !emitc.array<10x20x30xi32>,
48+
%arg1: !emitc.array<10x20x30xi32>,
49+
// CHECK-SAME: !emitc.array<30x!emitc.ptr<i32>>,
50+
%arg2: !emitc.array<30x!emitc.ptr<i32>>,
51+
// CHECK-SAME: !emitc.array<30x!emitc.opaque<"int">>
52+
%arg3: !emitc.array<30x!emitc.opaque<"int">>
53+
) {
54+
return
55+
}

mlir/test/Target/Cpp/common-cpp.mlir

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,8 @@ func.func @apply(%arg0: i32) -> !emitc.ptr<i32> {
8989
%1 = emitc.apply "*"(%0) : (!emitc.ptr<i32>) -> (i32)
9090
return %0 : !emitc.ptr<i32>
9191
}
92+
93+
// CHECK: void array_type(int32_t v1[3], float v2[10][20])
94+
func.func @array_type(%arg0: !emitc.array<3xi32>, %arg1: !emitc.array<10x20xf32>) {
95+
return
96+
}

mlir/test/Target/Cpp/variable.mlir

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ func.func @emitc_variable() {
99
%c4 = "emitc.variable"(){value = 255 : ui8} : () -> ui8
1010
%c5 = "emitc.variable"(){value = #emitc.opaque<"">} : () -> !emitc.ptr<i32>
1111
%c6 = "emitc.variable"(){value = #emitc.opaque<"NULL">} : () -> !emitc.ptr<i32>
12+
%c7 = "emitc.variable"(){value = #emitc.opaque<"">} : () -> !emitc.array<3x7xi32>
1213
return
1314
}
1415
// CPP-DEFAULT: void emitc_variable() {
@@ -19,6 +20,7 @@ func.func @emitc_variable() {
1920
// CPP-DEFAULT-NEXT: uint8_t [[V4:[^ ]*]] = 255;
2021
// CPP-DEFAULT-NEXT: int32_t* [[V5:[^ ]*]];
2122
// CPP-DEFAULT-NEXT: int32_t* [[V6:[^ ]*]] = NULL;
23+
// CPP-DEFAULT-NEXT: int32_t [[V7:[^ ]*]][3][7];
2224

2325
// CPP-DECLTOP: void emitc_variable() {
2426
// CPP-DECLTOP-NEXT: int32_t [[V0:[^ ]*]];
@@ -28,6 +30,7 @@ func.func @emitc_variable() {
2830
// CPP-DECLTOP-NEXT: uint8_t [[V4:[^ ]*]];
2931
// CPP-DECLTOP-NEXT: int32_t* [[V5:[^ ]*]];
3032
// CPP-DECLTOP-NEXT: int32_t* [[V6:[^ ]*]];
33+
// CPP-DECLTOP-NEXT: int32_t [[V7:[^ ]*]][3][7];
3134
// CPP-DECLTOP-NEXT: ;
3235
// CPP-DECLTOP-NEXT: [[V1]] = 42;
3336
// CPP-DECLTOP-NEXT: [[V2]] = -1;

0 commit comments

Comments
 (0)