Skip to content

Commit 8ebb681

Browse files
Simon Camphausenmgehre-amd
authored andcommitted
[mlir][EmitC] Add support for pointer and opaque types to subscript op (llvm#86266)
For pointer types the indices are restricted to one integer-like operand. For opaque types no further restrictions are made.
1 parent 79f313a commit 8ebb681

File tree

9 files changed

+176
-31
lines changed

9 files changed

+176
-31
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,17 @@
3030
namespace mlir {
3131
namespace emitc {
3232
void buildTerminatedBody(OpBuilder &builder, Location loc);
33+
3334
/// Determines whether \p type is valid in EmitC.
3435
bool isSupportedEmitCType(mlir::Type type);
36+
3537
/// Determines whether \p type is a valid integer type in EmitC.
3638
bool isSupportedIntegerType(mlir::Type type);
39+
40+
/// Determines whether \p type is integer like, i.e. it's a supported integer,
41+
/// an index or opaque type.
42+
bool isIntegerIndexOrOpaqueType(Type type);
43+
3744
/// Determines whether \p type is a valid floating-point type in EmitC.
3845
bool isSupportedFloatType(mlir::Type type);
3946
} // namespace emitc

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1224,35 +1224,41 @@ def EmitC_IfOp : EmitC_Op<"if",
12241224
let hasCustomAssemblyFormat = 1;
12251225
}
12261226

1227-
def EmitC_SubscriptOp : EmitC_Op<"subscript",
1228-
[TypesMatchWith<"result type matches element type of 'array'",
1229-
"array", "result",
1230-
"::llvm::cast<ArrayType>($_self).getElementType()">]> {
1231-
let summary = "Array subscript operation";
1227+
def EmitC_SubscriptOp : EmitC_Op<"subscript", []> {
1228+
let summary = "Subscript operation";
12321229
let description = [{
12331230
With the `subscript` operation the subscript operator `[]` can be applied
1234-
to variables or arguments of array type.
1231+
to variables or arguments of array, pointer and opaque type.
12351232

12361233
Example:
12371234

12381235
```mlir
12391236
%i = index.constant 1
12401237
%j = index.constant 7
1241-
%0 = emitc.subscript %arg0[%i, %j] : <4x8xf32>, index, index
1238+
%0 = emitc.subscript %arg0[%i, %j] : !emitc.array<4x8xf32>, index, index
1239+
%1 = emitc.subscript %arg1[%i] : !emitc.ptr<i32>, index
12421240
```
12431241
}];
1244-
let arguments = (ins Arg<EmitC_ArrayType, "the reference to load from">:$array,
1245-
Variadic<IntegerIndexOrOpaqueType>:$indices);
1242+
let arguments = (ins Arg<AnyTypeOf<[
1243+
EmitC_ArrayType,
1244+
EmitC_OpaqueType,
1245+
EmitC_PointerType]>,
1246+
"the value to subscript">:$value,
1247+
Variadic<AnyType>:$indices);
12461248
let results = (outs EmitCType:$result);
12471249

12481250
let builders = [
1249-
OpBuilder<(ins "Value":$array, "ValueRange":$indices), [{
1250-
build($_builder, $_state, cast<ArrayType>(array.getType()).getElementType(), array, indices);
1251+
OpBuilder<(ins "TypedValue<ArrayType>":$array, "ValueRange":$indices), [{
1252+
build($_builder, $_state, array.getType().getElementType(), array, indices);
1253+
}]>,
1254+
OpBuilder<(ins "TypedValue<PointerType>":$pointer, "Value":$index), [{
1255+
build($_builder, $_state, pointer.getType().getPointee(), pointer,
1256+
ValueRange{index});
12511257
}]>
12521258
];
12531259

12541260
let hasVerifier = 1;
1255-
let assemblyFormat = "$array `[` $indices `]` attr-dict `:` type($array) `,` type($indices)";
1261+
let assemblyFormat = "$value `[` $indices `]` attr-dict `:` functional-type(operands, results)";
12561262
}
12571263

12581264

mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,14 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
124124
return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
125125
}
126126

127+
auto arrayValue =
128+
dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
129+
if (!arrayValue) {
130+
return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
131+
}
132+
127133
auto subscript = rewriter.create<emitc::SubscriptOp>(
128-
op.getLoc(), operands.getMemref(), operands.getIndices());
134+
op.getLoc(), arrayValue, operands.getIndices());
129135

130136
auto noInit = emitc::OpaqueAttr::get(getContext(), "");
131137
auto var =
@@ -143,9 +149,14 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
143149
LogicalResult
144150
matchAndRewrite(memref::StoreOp op, OpAdaptor operands,
145151
ConversionPatternRewriter &rewriter) const override {
152+
auto arrayValue =
153+
dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
154+
if (!arrayValue) {
155+
return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
156+
}
146157

147158
auto subscript = rewriter.create<emitc::SubscriptOp>(
148-
op.getLoc(), operands.getMemref(), operands.getIndices());
159+
op.getLoc(), arrayValue, operands.getIndices());
149160
rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
150161
operands.getValue());
151162
return success();

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,11 @@ bool mlir::emitc::isSupportedIntegerType(Type type) {
108108
return false;
109109
}
110110

111+
bool mlir::emitc::isIntegerIndexOrOpaqueType(Type type) {
112+
return llvm::isa<IndexType, emitc::OpaqueType>(type) ||
113+
isSupportedIntegerType(type);
114+
}
115+
111116
bool mlir::emitc::isSupportedFloatType(Type type) {
112117
if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
113118
switch (floatType.getWidth()) {
@@ -818,12 +823,61 @@ LogicalResult emitc::YieldOp::verify() {
818823
//===----------------------------------------------------------------------===//
819824

820825
LogicalResult emitc::SubscriptOp::verify() {
821-
if (getIndices().size() != (size_t)getArray().getType().getRank()) {
822-
return emitOpError() << "requires number of indices ("
823-
<< getIndices().size()
824-
<< ") to match the rank of the array type ("
825-
<< getArray().getType().getRank() << ")";
826+
// Checks for array operand.
827+
if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(getValue().getType())) {
828+
// Check number of indices.
829+
if (getIndices().size() != (size_t)arrayType.getRank()) {
830+
return emitOpError() << "on array operand requires number of indices ("
831+
<< getIndices().size()
832+
<< ") to match the rank of the array type ("
833+
<< arrayType.getRank() << ")";
834+
}
835+
// Check types of index operands.
836+
for (unsigned i = 0, e = getIndices().size(); i != e; ++i) {
837+
Type type = getIndices()[i].getType();
838+
if (!isIntegerIndexOrOpaqueType(type)) {
839+
return emitOpError() << "on array operand requires index operand " << i
840+
<< " to be integer-like, but got " << type;
841+
}
842+
}
843+
// Check element type.
844+
Type elementType = arrayType.getElementType();
845+
if (elementType != getType()) {
846+
return emitOpError() << "on array operand requires element type ("
847+
<< elementType << ") and result type (" << getType()
848+
<< ") to match";
849+
}
850+
return success();
826851
}
852+
853+
// Checks for pointer operand.
854+
if (auto pointerType =
855+
llvm::dyn_cast<emitc::PointerType>(getValue().getType())) {
856+
// Check number of indices.
857+
if (getIndices().size() != 1) {
858+
return emitOpError()
859+
<< "on pointer operand requires one index operand, but got "
860+
<< getIndices().size();
861+
}
862+
// Check types of index operand.
863+
Type type = getIndices()[0].getType();
864+
if (!isIntegerIndexOrOpaqueType(type)) {
865+
return emitOpError() << "on pointer operand requires index operand to be "
866+
"integer-like, but got "
867+
<< type;
868+
}
869+
// Check pointee type.
870+
Type pointeeType = pointerType.getPointee();
871+
if (pointeeType != getType()) {
872+
return emitOpError() << "on pointer operand requires pointee type ("
873+
<< pointeeType << ") and result type (" << getType()
874+
<< ") to match";
875+
}
876+
return success();
877+
}
878+
879+
// The operand has opaque type, so we can't assume anything about the number
880+
// or types of index operands.
827881
return success();
828882
}
829883

mlir/lib/Target/Cpp/TranslateToCpp.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1120,7 +1120,7 @@ CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
11201120
std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
11211121
std::string out;
11221122
llvm::raw_string_ostream ss(out);
1123-
ss << getOrCreateName(op.getArray());
1123+
ss << getOrCreateName(op.getValue());
11241124
for (auto index : op.getIndices()) {
11251125
ss << "[" << getOrCreateName(index) << "]";
11261126
}

mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ func.func @memref_store(%v : f32, %i: index, %j: index) {
66
// CHECK: %[[ALLOCA:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
77
%0 = memref.alloca() : memref<4x8xf32>
88

9-
// CHECK: %[[SUBSCRIPT:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : <4x8xf32>
9+
// CHECK: %[[SUBSCRIPT:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : (!emitc.array<4x8xf32>, index, index) -> f32
1010
// CHECK: emitc.assign %[[v]] : f32 to %[[SUBSCRIPT:.*]] : f32
1111
memref.store %v, %0[%i, %j] : memref<4x8xf32>
1212
return
@@ -20,7 +20,7 @@ func.func @memref_load(%i: index, %j: index) -> f32 {
2020
// CHECK: %[[ALLOCA:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
2121
%0 = memref.alloca() : memref<4x8xf32>
2222

23-
// CHECK: %[[LOAD:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : <4x8xf32>
23+
// CHECK: %[[LOAD:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : (!emitc.array<4x8xf32>, index, index) -> f32
2424
// CHECK: %[[VAR:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
2525
// CHECK: emitc.assign %[[LOAD]] : f32 to %[[VAR]] : f32
2626
%1 = memref.load %0[%i, %j] : memref<4x8xf32>

mlir/test/Dialect/EmitC/invalid_ops.mlir

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -390,9 +390,49 @@ func.func @logical_or_resulterror(%arg0: i32, %arg1: i32) {
390390

391391
// -----
392392

393-
func.func @test_subscript_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg2: index) {
394-
// expected-error @+1 {{'emitc.subscript' op requires number of indices (1) to match the rank of the array type (2)}}
395-
%0 = emitc.subscript %arg0[%arg2] : <4x8xf32>, index
393+
func.func @test_subscript_array_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg1: index) {
394+
// expected-error @+1 {{'emitc.subscript' op on array operand requires number of indices (1) to match the rank of the array type (2)}}
395+
%0 = emitc.subscript %arg0[%arg1] : (!emitc.array<4x8xf32>, index) -> f32
396+
return
397+
}
398+
399+
// -----
400+
401+
func.func @test_subscript_array_index_type_mismatch(%arg0: !emitc.array<4x8xf32>, %arg1: index, %arg2: f32) {
402+
// expected-error @+1 {{'emitc.subscript' op on array operand requires index operand 1 to be integer-like, but got 'f32'}}
403+
%0 = emitc.subscript %arg0[%arg1, %arg2] : (!emitc.array<4x8xf32>, index, f32) -> f32
404+
return
405+
}
406+
407+
// -----
408+
409+
func.func @test_subscript_array_type_mismatch(%arg0: !emitc.array<4x8xf32>, %arg1: index, %arg2: index) {
410+
// expected-error @+1 {{'emitc.subscript' op on array operand requires element type ('f32') and result type ('i32') to match}}
411+
%0 = emitc.subscript %arg0[%arg1, %arg2] : (!emitc.array<4x8xf32>, index, index) -> i32
412+
return
413+
}
414+
415+
// -----
416+
417+
func.func @test_subscript_ptr_indices_mismatch(%arg0: !emitc.ptr<f32>, %arg1: index) {
418+
// expected-error @+1 {{'emitc.subscript' op on pointer operand requires one index operand, but got 2}}
419+
%0 = emitc.subscript %arg0[%arg1, %arg1] : (!emitc.ptr<f32>, index, index) -> f32
420+
return
421+
}
422+
423+
// -----
424+
425+
func.func @test_subscript_ptr_index_type_mismatch(%arg0: !emitc.ptr<f32>, %arg1: f64) {
426+
// expected-error @+1 {{'emitc.subscript' op on pointer operand requires index operand to be integer-like, but got 'f64'}}
427+
%0 = emitc.subscript %arg0[%arg1] : (!emitc.ptr<f32>, f64) -> f32
428+
return
429+
}
430+
431+
// -----
432+
433+
func.func @test_subscript_ptr_type_mismatch(%arg0: !emitc.ptr<f32>, %arg1: index) {
434+
// expected-error @+1 {{'emitc.subscript' op on pointer operand requires pointee type ('f32') and result type ('f64') to match}}
435+
%0 = emitc.subscript %arg0[%arg1] : (!emitc.ptr<f32>, index) -> f64
396436
return
397437
}
398438

mlir/test/Dialect/EmitC/ops.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,13 @@ func.func @test_for_not_index_induction(%arg0 : i16, %arg1 : i16, %arg2 : i16) {
214214
return
215215
}
216216

217+
func.func @test_subscript(%arg0 : !emitc.array<2x3xf32>, %arg1 : !emitc.ptr<i32>, %arg2 : !emitc.opaque<"std::map<char, int>">, %idx0 : index, %idx1 : i32, %idx2 : !emitc.opaque<"char">) {
218+
%0 = emitc.subscript %arg0[%idx0, %idx1] : (!emitc.array<2x3xf32>, index, i32) -> f32
219+
%1 = emitc.subscript %arg1[%idx0] : (!emitc.ptr<i32>, index) -> i32
220+
%2 = emitc.subscript %arg2[%idx2] : (!emitc.opaque<"std::map<char, int>">, !emitc.opaque<"char">) -> !emitc.opaque<"int">
221+
return
222+
}
223+
217224
emitc.verbatim "#ifdef __cplusplus"
218225
emitc.verbatim "extern \"C\" {"
219226
emitc.verbatim "#endif // __cplusplus"

mlir/test/Target/Cpp/subscript.mlir

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,44 @@
11
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
22
// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s
33

4-
func.func @load_store(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) {
5-
%0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>, index, index
6-
%1 = emitc.subscript %arg1[%arg2, %arg3] : <3x5xf32>, index, index
4+
func.func @load_store_array(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) {
5+
%0 = emitc.subscript %arg0[%arg2, %arg3] : (!emitc.array<4x8xf32>, index, index) -> f32
6+
%1 = emitc.subscript %arg1[%arg2, %arg3] : (!emitc.array<3x5xf32>, index, index) -> f32
77
emitc.assign %0 : f32 to %1 : f32
88
return
99
}
10-
// CHECK: void load_store(float [[ARR1:[^ ]*]][4][8], float [[ARR2:[^ ]*]][3][5],
10+
// CHECK: void load_store_array(float [[ARR1:[^ ]*]][4][8], float [[ARR2:[^ ]*]][3][5],
1111
// CHECK-SAME: size_t [[I:[^ ]*]], size_t [[J:[^ ]*]])
1212
// CHECK-NEXT: [[ARR2]][[[I]]][[[J]]] = [[ARR1]][[[I]]][[[J]]];
1313

14+
func.func @load_store_pointer(%arg0: !emitc.ptr<f32>, %arg1: !emitc.ptr<f32>, %arg2: index, %arg3: index) {
15+
%0 = emitc.subscript %arg0[%arg2] : (!emitc.ptr<f32>, index) -> f32
16+
%1 = emitc.subscript %arg1[%arg3] : (!emitc.ptr<f32>, index) -> f32
17+
emitc.assign %0 : f32 to %1 : f32
18+
return
19+
}
20+
// CHECK: void load_store_pointer(float* [[PTR1:[^ ]*]], float* [[PTR2:[^ ]*]],
21+
// CHECK-SAME: size_t [[I:[^ ]*]], size_t [[J:[^ ]*]])
22+
// CHECK-NEXT: [[PTR2]][[[J]]] = [[PTR1]][[[I]]];
23+
24+
func.func @load_store_opaque(%arg0: !emitc.opaque<"std::map<char, int>">, %arg1: !emitc.opaque<"std::map<char, int>">, %arg2: !emitc.opaque<"char">, %arg3: !emitc.opaque<"char">) {
25+
%0 = emitc.subscript %arg0[%arg2] : (!emitc.opaque<"std::map<char, int>">, !emitc.opaque<"char">) -> !emitc.opaque<"int">
26+
%1 = emitc.subscript %arg1[%arg3] : (!emitc.opaque<"std::map<char, int>">, !emitc.opaque<"char">) -> !emitc.opaque<"int">
27+
emitc.assign %0 : !emitc.opaque<"int"> to %1 : !emitc.opaque<"int">
28+
return
29+
}
30+
// CHECK: void load_store_opaque(std::map<char, int> [[MAP1:[^ ]*]], std::map<char, int> [[MAP2:[^ ]*]],
31+
// CHECK-SAME: char [[I:[^ ]*]], char [[J:[^ ]*]])
32+
// CHECK-NEXT: [[MAP2]][[[J]]] = [[MAP1]][[[I]]];
33+
1434
emitc.func @func1(%arg0 : f32) {
1535
emitc.return
1636
}
1737

1838
emitc.func @call_arg(%arg0: !emitc.array<4x8xf32>, %i: i32, %j: i16,
1939
%k: i8) {
20-
%0 = emitc.subscript %arg0[%i, %j] : <4x8xf32>, i32, i16
21-
%1 = emitc.subscript %arg0[%j, %k] : <4x8xf32>, i16, i8
40+
%0 = emitc.subscript %arg0[%i, %j] : (!emitc.array<4x8xf32>, i32, i16) -> f32
41+
%1 = emitc.subscript %arg0[%j, %k] : (!emitc.array<4x8xf32>, i16, i8) -> f32
2242

2343
emitc.call @func1 (%0) : (f32) -> ()
2444
emitc.call_opaque "func2" (%1) : (f32) -> ()

0 commit comments

Comments
 (0)