-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][EmitC] Add support for pointer and opaque types to subscript op #86266
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
For pointer types the indices are restricted to one integer-like operand. For opaque types no further restrictions are made.
@llvm/pr-subscribers-mlir Author: Simon Camphausen (simon-camp) ChangesFor pointer types the indices are restricted to one integer-like operand. Full diff: https://github.com/llvm/llvm-project/pull/86266.diff 9 Files Affected:
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
index 725a1bcb4e6cb1..d2f20b642b26b2 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
@@ -30,8 +30,14 @@
namespace mlir {
namespace emitc {
void buildTerminatedBody(OpBuilder &builder, Location loc);
+
/// Determines whether \p type is a valid integer type in EmitC.
bool isSupportedIntegerType(mlir::Type type);
+
+/// Determines whether \p type is integer like, i.e. it's a supported integer,
+/// an index or opaque type.
+bool isIntegerLikeType(Type type);
+
/// Determines whether \p type is a valid floating-point type in EmitC.
bool isSupportedFloatType(mlir::Type type);
} // namespace emitc
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index d746222ff37a4b..539a4f3e9805e1 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1155,35 +1155,41 @@ def EmitC_IfOp : EmitC_Op<"if",
let hasCustomAssemblyFormat = 1;
}
-def EmitC_SubscriptOp : EmitC_Op<"subscript",
- [TypesMatchWith<"result type matches element type of 'array'",
- "array", "result",
- "::llvm::cast<ArrayType>($_self).getElementType()">]> {
- let summary = "Array subscript operation";
+def EmitC_SubscriptOp : EmitC_Op<"subscript", []> {
+ let summary = "Subscript operation";
let description = [{
With the `subscript` operation the subscript operator `[]` can be applied
- to variables or arguments of array type.
+ to variables or arguments of array, pointer and opaque type.
Example:
```mlir
%i = index.constant 1
%j = index.constant 7
- %0 = emitc.subscript %arg0[%i, %j] : <4x8xf32>, index, index
+ %0 = emitc.subscript %arg0[%i, %j] : !emitc.array<4x8xf32>, index, index
+ %1 = emitc.subscript %arg1[%i] : !emitc.ptr<i32>, index
```
}];
- let arguments = (ins Arg<EmitC_ArrayType, "the reference to load from">:$array,
- Variadic<IntegerIndexOrOpaqueType>:$indices);
+ let arguments = (ins Arg<AnyTypeOf<[
+ EmitC_ArrayType,
+ EmitC_OpaqueType,
+ EmitC_PointerType]>,
+ "the reference to load from">:$ref,
+ Variadic<AnyType>:$indices);
let results = (outs AnyType:$result);
let builders = [
- OpBuilder<(ins "Value":$array, "ValueRange":$indices), [{
- build($_builder, $_state, cast<ArrayType>(array.getType()).getElementType(), array, indices);
+ OpBuilder<(ins "TypedValue<ArrayType>":$array, "ValueRange":$indices), [{
+ build($_builder, $_state, array.getType().getElementType(), array, indices);
+ }]>,
+ OpBuilder<(ins "TypedValue<PointerType>":$pointer, "Value":$index), [{
+ build($_builder, $_state, pointer.getType().getPointee(), pointer,
+ ValueRange{index});
}]>
];
let hasVerifier = 1;
- let assemblyFormat = "$array `[` $indices `]` attr-dict `:` type($array) `,` type($indices)";
+ let assemblyFormat = "$ref `[` $indices `]` attr-dict `:` functional-type(operands, results)";
}
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 0e3b6469212640..3a2405a6195437 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -63,7 +63,8 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
}
auto subscript = rewriter.create<emitc::SubscriptOp>(
- op.getLoc(), operands.getMemref(), operands.getIndices());
+ op.getLoc(), cast<TypedValue<emitc::ArrayType>>(operands.getMemref()),
+ operands.getIndices());
auto noInit = emitc::OpaqueAttr::get(getContext(), "");
auto var =
@@ -83,7 +84,8 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
ConversionPatternRewriter &rewriter) const override {
auto subscript = rewriter.create<emitc::SubscriptOp>(
- op.getLoc(), operands.getMemref(), operands.getIndices());
+ op.getLoc(), cast<TypedValue<emitc::ArrayType>>(operands.getMemref()),
+ operands.getIndices());
rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
operands.getValue());
return success();
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index ab5c418e844fbf..f364573552fe97 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -70,6 +70,11 @@ bool mlir::emitc::isSupportedIntegerType(Type type) {
return false;
}
+bool mlir::emitc::isIntegerLikeType(Type type) {
+ return isSupportedIntegerType(type) ||
+ llvm::isa<IndexType, emitc::OpaqueType>(type);
+}
+
bool mlir::emitc::isSupportedFloatType(Type type) {
if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
switch (floatType.getWidth()) {
@@ -781,11 +786,52 @@ LogicalResult emitc::YieldOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult emitc::SubscriptOp::verify() {
- if (getIndices().size() != (size_t)getArray().getType().getRank()) {
- return emitOpError() << "requires number of indices ("
- << getIndices().size()
- << ") to match the rank of the array type ("
- << getArray().getType().getRank() << ")";
+ if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(getRef().getType())) {
+ // Check arity of indices.
+ if (getIndices().size() != (size_t)arrayType.getRank()) {
+ return emitOpError() << "requires number of indices ("
+ << getIndices().size()
+ << ") to match the rank of the array type ("
+ << arrayType.getRank() << ")";
+ }
+ // Check types of index operands.
+ for (unsigned i = 0, e = getIndices().size(); i != e; ++i) {
+ Type type = getIndices()[i].getType();
+ if (!isIntegerLikeType(type)) {
+ return emitOpError() << "requires index operand " << i
+ << " to be integer-like, but got " << type;
+ }
+ }
+ // Check element type.
+ Type elementType = arrayType.getElementType();
+ if (elementType != getType()) {
+ return emitOpError() << "requires element type (" << elementType
+ << ") and result type (" << getType()
+ << ") to match";
+ }
+ } else if (auto pointerType =
+ llvm::dyn_cast<emitc::PointerType>(getRef().getType())) {
+ // Check arity of indices.
+ if (getIndices().size() != 1) {
+ return emitOpError() << "requires one index operand, but got "
+ << getIndices().size();
+ }
+ // Check types of index operand.
+ Type type = getIndices()[0].getType();
+ if (!isIntegerLikeType(type)) {
+ return emitOpError()
+ << "requires index operand to be integer-like, but got " << type;
+ }
+ // Check pointee type.
+ Type pointeeType = pointerType.getPointee();
+ if (pointeeType != getType()) {
+ return emitOpError() << "requires pointee type (" << pointeeType
+ << ") and result type (" << getType()
+ << ") to match";
+ }
+ } else {
+ // The reference has opaque type, so we can't assume anything about arity or
+ // types of index operands.
}
return success();
}
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 95c7af2f07be46..8fd04b7d1a51e0 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -1105,7 +1105,7 @@ CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
std::string out;
llvm::raw_string_ostream ss(out);
- ss << getOrCreateName(op.getArray());
+ ss << getOrCreateName(op.getRef());
for (auto index : op.getIndices()) {
ss << "[" << getOrCreateName(index) << "]";
}
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
index 9793b2d6d7832f..7aa2ba88843a2a 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -6,7 +6,7 @@ func.func @memref_store(%v : f32, %i: index, %j: index) {
// CHECK: %[[ALLOCA:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
%0 = memref.alloca() : memref<4x8xf32>
- // CHECK: %[[SUBSCRIPT:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : <4x8xf32>
+ // CHECK: %[[SUBSCRIPT:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : (!emitc.array<4x8xf32>, index, index) -> f32
// CHECK: emitc.assign %[[v]] : f32 to %[[SUBSCRIPT:.*]] : f32
memref.store %v, %0[%i, %j] : memref<4x8xf32>
return
@@ -19,7 +19,7 @@ func.func @memref_load(%i: index, %j: index) -> f32 {
// CHECK: %[[ALLOCA:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
%0 = memref.alloca() : memref<4x8xf32>
- // CHECK: %[[LOAD:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : <4x8xf32>
+ // CHECK: %[[LOAD:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : (!emitc.array<4x8xf32>, index, index) -> f32
// CHECK: %[[VAR:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
// CHECK: emitc.assign %[[LOAD]] : f32 to %[[VAR]] : f32
%1 = memref.load %0[%i, %j] : memref<4x8xf32>
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 22423cf61b5556..321e4c01110e82 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -390,8 +390,48 @@ func.func @logical_or_resulterror(%arg0: i32, %arg1: i32) {
// -----
-func.func @test_subscript_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg2: index) {
+func.func @test_subscript_array_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg1: index) {
// expected-error @+1 {{'emitc.subscript' op requires number of indices (1) to match the rank of the array type (2)}}
- %0 = emitc.subscript %arg0[%arg2] : <4x8xf32>, index
+ %0 = emitc.subscript %arg0[%arg1] : (!emitc.array<4x8xf32>, index) -> f32
+ return
+}
+
+// -----
+
+func.func @test_subscript_array_index_type_mismatch(%arg0: !emitc.array<4x8xf32>, %arg1: index, %arg2: f32) {
+ // expected-error @+1 {{'emitc.subscript' op requires index operand 1 to be integer-like, but got 'f32'}}
+ %0 = emitc.subscript %arg0[%arg1, %arg2] : (!emitc.array<4x8xf32>, index, f32) -> f32
+ return
+}
+
+// -----
+
+func.func @test_subscript_array_type_mismatch(%arg0: !emitc.array<4x8xf32>, %arg1: index, %arg2: index) {
+ // expected-error @+1 {{'emitc.subscript' op requires element type ('f32') and result type ('i32') to match}}
+ %0 = emitc.subscript %arg0[%arg1, %arg2] : (!emitc.array<4x8xf32>, index, index) -> i32
+ return
+}
+
+// -----
+
+func.func @test_subscript_ptr_indices_mismatch(%arg0: !emitc.ptr<f32>, %arg2: index) {
+ // expected-error @+1 {{'emitc.subscript' op requires one index operand, but got 2}}
+ %0 = emitc.subscript %arg0[%arg2, %arg2] : (!emitc.ptr<f32>, index, index) -> f32
+ return
+}
+
+// -----
+
+func.func @test_subscript_ptr_index_type_mismatch(%arg0: !emitc.ptr<f32>, %arg2: f64) {
+ // expected-error @+1 {{'emitc.subscript' op requires index operand to be integer-like, but got 'f64'}}
+ %0 = emitc.subscript %arg0[%arg2] : (!emitc.ptr<f32>, f64) -> f32
+ return
+}
+
+// -----
+
+func.func @test_subscript_ptr_type_mismatch(%arg0: !emitc.ptr<f32>, %arg2: index) {
+ // expected-error @+1 {{'emitc.subscript' op requires pointee type ('f32') and result type ('f64') to match}}
+ %0 = emitc.subscript %arg0[%arg2] : (!emitc.ptr<f32>, index) -> f64
return
}
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index 5f00a295ed740e..ace3670426afa5 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -214,6 +214,13 @@ func.func @test_for_not_index_induction(%arg0 : i16, %arg1 : i16, %arg2 : i16) {
return
}
+func.func @test_subscript(%arg0 : !emitc.array<2x3xf32>, %arg1 : !emitc.ptr<i32>, %arg2 : !emitc.opaque<"std::map<char, int>">, %idx0 : index, %idx1 : i32, %idx2 : !emitc.opaque<"char">) {
+ %0 = emitc.subscript %arg0[%idx0, %idx1] : (!emitc.array<2x3xf32>, index, i32) -> f32
+ %1 = emitc.subscript %arg1[%idx0] : (!emitc.ptr<i32>, index) -> i32
+ %2 = emitc.subscript %arg2[%idx2] : (!emitc.opaque<"std::map<char, int>">, !emitc.opaque<"char">) -> !emitc.opaque<"int">
+ return
+}
+
emitc.verbatim "#ifdef __cplusplus"
emitc.verbatim "extern \"C\" {"
emitc.verbatim "#endif // __cplusplus"
diff --git a/mlir/test/Target/Cpp/subscript.mlir b/mlir/test/Target/Cpp/subscript.mlir
index a6c82df9111a79..0b388953c80d37 100644
--- a/mlir/test/Target/Cpp/subscript.mlir
+++ b/mlir/test/Target/Cpp/subscript.mlir
@@ -1,24 +1,44 @@
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s
-func.func @load_store(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) {
- %0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>, index, index
- %1 = emitc.subscript %arg1[%arg2, %arg3] : <3x5xf32>, index, index
+func.func @load_store_array(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) {
+ %0 = emitc.subscript %arg0[%arg2, %arg3] : (!emitc.array<4x8xf32>, index, index) -> f32
+ %1 = emitc.subscript %arg1[%arg2, %arg3] : (!emitc.array<3x5xf32>, index, index) -> f32
emitc.assign %0 : f32 to %1 : f32
return
}
-// CHECK: void load_store(float [[ARR1:[^ ]*]][4][8], float [[ARR2:[^ ]*]][3][5],
+// CHECK: void load_store_array(float [[ARR1:[^ ]*]][4][8], float [[ARR2:[^ ]*]][3][5],
// CHECK-SAME: size_t [[I:[^ ]*]], size_t [[J:[^ ]*]])
// CHECK-NEXT: [[ARR2]][[[I]]][[[J]]] = [[ARR1]][[[I]]][[[J]]];
+func.func @load_store_pointer(%arg0: !emitc.ptr<f32>, %arg1: !emitc.ptr<f32>, %arg2: index, %arg3: index) {
+ %0 = emitc.subscript %arg0[%arg2] : (!emitc.ptr<f32>, index) -> f32
+ %1 = emitc.subscript %arg1[%arg3] : (!emitc.ptr<f32>, index) -> f32
+ emitc.assign %0 : f32 to %1 : f32
+ return
+}
+// CHECK: void load_store_pointer(float* [[PTR1:[^ ]*]], float* [[PTR2:[^ ]*]],
+// CHECK-SAME: size_t [[I:[^ ]*]], size_t [[J:[^ ]*]])
+// CHECK-NEXT: [[PTR2]][[[J]]] = [[PTR1]][[[I]]];
+
+func.func @load_store_opaque(%arg0: !emitc.opaque<"std::map<char, int>">, %arg1: !emitc.opaque<"std::map<char, int>">, %arg2: !emitc.opaque<"char">, %arg3: !emitc.opaque<"char">) {
+ %0 = emitc.subscript %arg0[%arg2] : (!emitc.opaque<"std::map<char, int>">, !emitc.opaque<"char">) -> !emitc.opaque<"int">
+ %1 = emitc.subscript %arg1[%arg3] : (!emitc.opaque<"std::map<char, int>">, !emitc.opaque<"char">) -> !emitc.opaque<"int">
+ emitc.assign %0 : !emitc.opaque<"int"> to %1 : !emitc.opaque<"int">
+ return
+}
+// CHECK: void load_store_opaque(std::map<char, int> [[MAP1:[^ ]*]], std::map<char, int> [[MAP2:[^ ]*]],
+// CHECK-SAME: char [[I:[^ ]*]], char [[J:[^ ]*]])
+// CHECK-NEXT: [[MAP2]][[[J]]] = [[MAP1]][[[I]]];
+
emitc.func @func1(%arg0 : f32) {
emitc.return
}
emitc.func @call_arg(%arg0: !emitc.array<4x8xf32>, %i: i32, %j: i16,
%k: i8) {
- %0 = emitc.subscript %arg0[%i, %j] : <4x8xf32>, i32, i16
- %1 = emitc.subscript %arg0[%j, %k] : <4x8xf32>, i16, i8
+ %0 = emitc.subscript %arg0[%i, %j] : (!emitc.array<4x8xf32>, i32, i16) -> f32
+ %1 = emitc.subscript %arg0[%j, %k] : (!emitc.array<4x8xf32>, i16, i8) -> f32
emitc.call @func1 (%0) : (f32) -> ()
emitc.call_opaque "func2" (%1) : (f32) -> ()
|
@llvm/pr-subscribers-mlir-emitc Author: Simon Camphausen (simon-camp) ChangesFor pointer types the indices are restricted to one integer-like operand. Full diff: https://github.com/llvm/llvm-project/pull/86266.diff 9 Files Affected:
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
index 725a1bcb4e6cb1..d2f20b642b26b2 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
@@ -30,8 +30,14 @@
namespace mlir {
namespace emitc {
void buildTerminatedBody(OpBuilder &builder, Location loc);
+
/// Determines whether \p type is a valid integer type in EmitC.
bool isSupportedIntegerType(mlir::Type type);
+
+/// Determines whether \p type is integer like, i.e. it's a supported integer,
+/// an index or opaque type.
+bool isIntegerLikeType(Type type);
+
/// Determines whether \p type is a valid floating-point type in EmitC.
bool isSupportedFloatType(mlir::Type type);
} // namespace emitc
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index d746222ff37a4b..539a4f3e9805e1 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1155,35 +1155,41 @@ def EmitC_IfOp : EmitC_Op<"if",
let hasCustomAssemblyFormat = 1;
}
-def EmitC_SubscriptOp : EmitC_Op<"subscript",
- [TypesMatchWith<"result type matches element type of 'array'",
- "array", "result",
- "::llvm::cast<ArrayType>($_self).getElementType()">]> {
- let summary = "Array subscript operation";
+def EmitC_SubscriptOp : EmitC_Op<"subscript", []> {
+ let summary = "Subscript operation";
let description = [{
With the `subscript` operation the subscript operator `[]` can be applied
- to variables or arguments of array type.
+ to variables or arguments of array, pointer and opaque type.
Example:
```mlir
%i = index.constant 1
%j = index.constant 7
- %0 = emitc.subscript %arg0[%i, %j] : <4x8xf32>, index, index
+ %0 = emitc.subscript %arg0[%i, %j] : !emitc.array<4x8xf32>, index, index
+ %1 = emitc.subscript %arg1[%i] : !emitc.ptr<i32>, index
```
}];
- let arguments = (ins Arg<EmitC_ArrayType, "the reference to load from">:$array,
- Variadic<IntegerIndexOrOpaqueType>:$indices);
+ let arguments = (ins Arg<AnyTypeOf<[
+ EmitC_ArrayType,
+ EmitC_OpaqueType,
+ EmitC_PointerType]>,
+ "the reference to load from">:$ref,
+ Variadic<AnyType>:$indices);
let results = (outs AnyType:$result);
let builders = [
- OpBuilder<(ins "Value":$array, "ValueRange":$indices), [{
- build($_builder, $_state, cast<ArrayType>(array.getType()).getElementType(), array, indices);
+ OpBuilder<(ins "TypedValue<ArrayType>":$array, "ValueRange":$indices), [{
+ build($_builder, $_state, array.getType().getElementType(), array, indices);
+ }]>,
+ OpBuilder<(ins "TypedValue<PointerType>":$pointer, "Value":$index), [{
+ build($_builder, $_state, pointer.getType().getPointee(), pointer,
+ ValueRange{index});
}]>
];
let hasVerifier = 1;
- let assemblyFormat = "$array `[` $indices `]` attr-dict `:` type($array) `,` type($indices)";
+ let assemblyFormat = "$ref `[` $indices `]` attr-dict `:` functional-type(operands, results)";
}
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 0e3b6469212640..3a2405a6195437 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -63,7 +63,8 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
}
auto subscript = rewriter.create<emitc::SubscriptOp>(
- op.getLoc(), operands.getMemref(), operands.getIndices());
+ op.getLoc(), cast<TypedValue<emitc::ArrayType>>(operands.getMemref()),
+ operands.getIndices());
auto noInit = emitc::OpaqueAttr::get(getContext(), "");
auto var =
@@ -83,7 +84,8 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
ConversionPatternRewriter &rewriter) const override {
auto subscript = rewriter.create<emitc::SubscriptOp>(
- op.getLoc(), operands.getMemref(), operands.getIndices());
+ op.getLoc(), cast<TypedValue<emitc::ArrayType>>(operands.getMemref()),
+ operands.getIndices());
rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
operands.getValue());
return success();
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index ab5c418e844fbf..f364573552fe97 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -70,6 +70,11 @@ bool mlir::emitc::isSupportedIntegerType(Type type) {
return false;
}
+bool mlir::emitc::isIntegerLikeType(Type type) {
+ return isSupportedIntegerType(type) ||
+ llvm::isa<IndexType, emitc::OpaqueType>(type);
+}
+
bool mlir::emitc::isSupportedFloatType(Type type) {
if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
switch (floatType.getWidth()) {
@@ -781,11 +786,52 @@ LogicalResult emitc::YieldOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult emitc::SubscriptOp::verify() {
- if (getIndices().size() != (size_t)getArray().getType().getRank()) {
- return emitOpError() << "requires number of indices ("
- << getIndices().size()
- << ") to match the rank of the array type ("
- << getArray().getType().getRank() << ")";
+ if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(getRef().getType())) {
+ // Check arity of indices.
+ if (getIndices().size() != (size_t)arrayType.getRank()) {
+ return emitOpError() << "requires number of indices ("
+ << getIndices().size()
+ << ") to match the rank of the array type ("
+ << arrayType.getRank() << ")";
+ }
+ // Check types of index operands.
+ for (unsigned i = 0, e = getIndices().size(); i != e; ++i) {
+ Type type = getIndices()[i].getType();
+ if (!isIntegerLikeType(type)) {
+ return emitOpError() << "requires index operand " << i
+ << " to be integer-like, but got " << type;
+ }
+ }
+ // Check element type.
+ Type elementType = arrayType.getElementType();
+ if (elementType != getType()) {
+ return emitOpError() << "requires element type (" << elementType
+ << ") and result type (" << getType()
+ << ") to match";
+ }
+ } else if (auto pointerType =
+ llvm::dyn_cast<emitc::PointerType>(getRef().getType())) {
+ // Check arity of indices.
+ if (getIndices().size() != 1) {
+ return emitOpError() << "requires one index operand, but got "
+ << getIndices().size();
+ }
+ // Check types of index operand.
+ Type type = getIndices()[0].getType();
+ if (!isIntegerLikeType(type)) {
+ return emitOpError()
+ << "requires index operand to be integer-like, but got " << type;
+ }
+ // Check pointee type.
+ Type pointeeType = pointerType.getPointee();
+ if (pointeeType != getType()) {
+ return emitOpError() << "requires pointee type (" << pointeeType
+ << ") and result type (" << getType()
+ << ") to match";
+ }
+ } else {
+ // The reference has opaque type, so we can't assume anything about arity or
+ // types of index operands.
}
return success();
}
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 95c7af2f07be46..8fd04b7d1a51e0 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -1105,7 +1105,7 @@ CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
std::string out;
llvm::raw_string_ostream ss(out);
- ss << getOrCreateName(op.getArray());
+ ss << getOrCreateName(op.getRef());
for (auto index : op.getIndices()) {
ss << "[" << getOrCreateName(index) << "]";
}
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
index 9793b2d6d7832f..7aa2ba88843a2a 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -6,7 +6,7 @@ func.func @memref_store(%v : f32, %i: index, %j: index) {
// CHECK: %[[ALLOCA:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
%0 = memref.alloca() : memref<4x8xf32>
- // CHECK: %[[SUBSCRIPT:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : <4x8xf32>
+ // CHECK: %[[SUBSCRIPT:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : (!emitc.array<4x8xf32>, index, index) -> f32
// CHECK: emitc.assign %[[v]] : f32 to %[[SUBSCRIPT:.*]] : f32
memref.store %v, %0[%i, %j] : memref<4x8xf32>
return
@@ -19,7 +19,7 @@ func.func @memref_load(%i: index, %j: index) -> f32 {
// CHECK: %[[ALLOCA:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
%0 = memref.alloca() : memref<4x8xf32>
- // CHECK: %[[LOAD:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : <4x8xf32>
+ // CHECK: %[[LOAD:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : (!emitc.array<4x8xf32>, index, index) -> f32
// CHECK: %[[VAR:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
// CHECK: emitc.assign %[[LOAD]] : f32 to %[[VAR]] : f32
%1 = memref.load %0[%i, %j] : memref<4x8xf32>
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 22423cf61b5556..321e4c01110e82 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -390,8 +390,48 @@ func.func @logical_or_resulterror(%arg0: i32, %arg1: i32) {
// -----
-func.func @test_subscript_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg2: index) {
+func.func @test_subscript_array_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg1: index) {
// expected-error @+1 {{'emitc.subscript' op requires number of indices (1) to match the rank of the array type (2)}}
- %0 = emitc.subscript %arg0[%arg2] : <4x8xf32>, index
+ %0 = emitc.subscript %arg0[%arg1] : (!emitc.array<4x8xf32>, index) -> f32
+ return
+}
+
+// -----
+
+func.func @test_subscript_array_index_type_mismatch(%arg0: !emitc.array<4x8xf32>, %arg1: index, %arg2: f32) {
+ // expected-error @+1 {{'emitc.subscript' op requires index operand 1 to be integer-like, but got 'f32'}}
+ %0 = emitc.subscript %arg0[%arg1, %arg2] : (!emitc.array<4x8xf32>, index, f32) -> f32
+ return
+}
+
+// -----
+
+func.func @test_subscript_array_type_mismatch(%arg0: !emitc.array<4x8xf32>, %arg1: index, %arg2: index) {
+ // expected-error @+1 {{'emitc.subscript' op requires element type ('f32') and result type ('i32') to match}}
+ %0 = emitc.subscript %arg0[%arg1, %arg2] : (!emitc.array<4x8xf32>, index, index) -> i32
+ return
+}
+
+// -----
+
+func.func @test_subscript_ptr_indices_mismatch(%arg0: !emitc.ptr<f32>, %arg2: index) {
+ // expected-error @+1 {{'emitc.subscript' op requires one index operand, but got 2}}
+ %0 = emitc.subscript %arg0[%arg2, %arg2] : (!emitc.ptr<f32>, index, index) -> f32
+ return
+}
+
+// -----
+
+func.func @test_subscript_ptr_index_type_mismatch(%arg0: !emitc.ptr<f32>, %arg2: f64) {
+ // expected-error @+1 {{'emitc.subscript' op requires index operand to be integer-like, but got 'f64'}}
+ %0 = emitc.subscript %arg0[%arg2] : (!emitc.ptr<f32>, f64) -> f32
+ return
+}
+
+// -----
+
+func.func @test_subscript_ptr_type_mismatch(%arg0: !emitc.ptr<f32>, %arg2: index) {
+ // expected-error @+1 {{'emitc.subscript' op requires pointee type ('f32') and result type ('f64') to match}}
+ %0 = emitc.subscript %arg0[%arg2] : (!emitc.ptr<f32>, index) -> f64
return
}
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index 5f00a295ed740e..ace3670426afa5 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -214,6 +214,13 @@ func.func @test_for_not_index_induction(%arg0 : i16, %arg1 : i16, %arg2 : i16) {
return
}
+func.func @test_subscript(%arg0 : !emitc.array<2x3xf32>, %arg1 : !emitc.ptr<i32>, %arg2 : !emitc.opaque<"std::map<char, int>">, %idx0 : index, %idx1 : i32, %idx2 : !emitc.opaque<"char">) {
+ %0 = emitc.subscript %arg0[%idx0, %idx1] : (!emitc.array<2x3xf32>, index, i32) -> f32
+ %1 = emitc.subscript %arg1[%idx0] : (!emitc.ptr<i32>, index) -> i32
+ %2 = emitc.subscript %arg2[%idx2] : (!emitc.opaque<"std::map<char, int>">, !emitc.opaque<"char">) -> !emitc.opaque<"int">
+ return
+}
+
emitc.verbatim "#ifdef __cplusplus"
emitc.verbatim "extern \"C\" {"
emitc.verbatim "#endif // __cplusplus"
diff --git a/mlir/test/Target/Cpp/subscript.mlir b/mlir/test/Target/Cpp/subscript.mlir
index a6c82df9111a79..0b388953c80d37 100644
--- a/mlir/test/Target/Cpp/subscript.mlir
+++ b/mlir/test/Target/Cpp/subscript.mlir
@@ -1,24 +1,44 @@
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s
-func.func @load_store(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) {
- %0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>, index, index
- %1 = emitc.subscript %arg1[%arg2, %arg3] : <3x5xf32>, index, index
+func.func @load_store_array(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) {
+ %0 = emitc.subscript %arg0[%arg2, %arg3] : (!emitc.array<4x8xf32>, index, index) -> f32
+ %1 = emitc.subscript %arg1[%arg2, %arg3] : (!emitc.array<3x5xf32>, index, index) -> f32
emitc.assign %0 : f32 to %1 : f32
return
}
-// CHECK: void load_store(float [[ARR1:[^ ]*]][4][8], float [[ARR2:[^ ]*]][3][5],
+// CHECK: void load_store_array(float [[ARR1:[^ ]*]][4][8], float [[ARR2:[^ ]*]][3][5],
// CHECK-SAME: size_t [[I:[^ ]*]], size_t [[J:[^ ]*]])
// CHECK-NEXT: [[ARR2]][[[I]]][[[J]]] = [[ARR1]][[[I]]][[[J]]];
+func.func @load_store_pointer(%arg0: !emitc.ptr<f32>, %arg1: !emitc.ptr<f32>, %arg2: index, %arg3: index) {
+ %0 = emitc.subscript %arg0[%arg2] : (!emitc.ptr<f32>, index) -> f32
+ %1 = emitc.subscript %arg1[%arg3] : (!emitc.ptr<f32>, index) -> f32
+ emitc.assign %0 : f32 to %1 : f32
+ return
+}
+// CHECK: void load_store_pointer(float* [[PTR1:[^ ]*]], float* [[PTR2:[^ ]*]],
+// CHECK-SAME: size_t [[I:[^ ]*]], size_t [[J:[^ ]*]])
+// CHECK-NEXT: [[PTR2]][[[J]]] = [[PTR1]][[[I]]];
+
+func.func @load_store_opaque(%arg0: !emitc.opaque<"std::map<char, int>">, %arg1: !emitc.opaque<"std::map<char, int>">, %arg2: !emitc.opaque<"char">, %arg3: !emitc.opaque<"char">) {
+ %0 = emitc.subscript %arg0[%arg2] : (!emitc.opaque<"std::map<char, int>">, !emitc.opaque<"char">) -> !emitc.opaque<"int">
+ %1 = emitc.subscript %arg1[%arg3] : (!emitc.opaque<"std::map<char, int>">, !emitc.opaque<"char">) -> !emitc.opaque<"int">
+ emitc.assign %0 : !emitc.opaque<"int"> to %1 : !emitc.opaque<"int">
+ return
+}
+// CHECK: void load_store_opaque(std::map<char, int> [[MAP1:[^ ]*]], std::map<char, int> [[MAP2:[^ ]*]],
+// CHECK-SAME: char [[I:[^ ]*]], char [[J:[^ ]*]])
+// CHECK-NEXT: [[MAP2]][[[J]]] = [[MAP1]][[[I]]];
+
emitc.func @func1(%arg0 : f32) {
emitc.return
}
emitc.func @call_arg(%arg0: !emitc.array<4x8xf32>, %i: i32, %j: i16,
%k: i8) {
- %0 = emitc.subscript %arg0[%i, %j] : <4x8xf32>, i32, i16
- %1 = emitc.subscript %arg0[%j, %k] : <4x8xf32>, i16, i8
+ %0 = emitc.subscript %arg0[%i, %j] : (!emitc.array<4x8xf32>, i32, i16) -> f32
+ %1 = emitc.subscript %arg0[%j, %k] : (!emitc.array<4x8xf32>, i16, i8) -> f32
emitc.call @func1 (%0) : (f32) -> ()
emitc.call_opaque "func2" (%1) : (f32) -> ()
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice, I really like to see progress here. Some first comments, need to take a closer look after lunch.
@@ -63,7 +63,8 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> { | |||
} | |||
|
|||
auto subscript = rewriter.create<emitc::SubscriptOp>( | |||
op.getLoc(), operands.getMemref(), operands.getIndices()); | |||
op.getLoc(), cast<TypedValue<emitc::ArrayType>>(operands.getMemref()), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we allow the builder to take Value
and leave the verification to the verifier? This is currently more verbose to write, and when it fails it would assert instead of being a verifier error.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I switched to TypedValue to distinguish array and pointer values. These are the builders we have now:
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, TypedValue<ArrayType> array, ValueRange indices);
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, TypedValue<PointerType> pointer, Value index);
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type result, ::mlir::Value value, ::mlir::ValueRange indices);
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value value, ::mlir::ValueRange indices);
static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
I will use the third one and pass the type explicitly.
Instead I now try a dyn_cast and fail gracefully.
llvm#86266) For pointer types the indices are restricted to one integer-like operand. For opaque types no further restrictions are made.
For pointer types the indices are restricted to one integer-like operand.
For opaque types no further restrictions are made.