Skip to content

[mlir][EmitC] Add support for pointer and opaque types to subscript op #86266

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,14 @@
namespace mlir {
namespace emitc {
void buildTerminatedBody(OpBuilder &builder, Location loc);

/// Determines whether \p type is a valid integer type in EmitC.
bool isSupportedIntegerType(mlir::Type type);

/// Determines whether \p type is integer like, i.e. it's a supported integer,
/// an index or opaque type.
bool isIntegerIndexOrOpaqueType(Type type);

/// Determines whether \p type is a valid floating-point type in EmitC.
bool isSupportedFloatType(mlir::Type type);
} // namespace emitc
Expand Down
30 changes: 18 additions & 12 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -1155,35 +1155,41 @@ def EmitC_IfOp : EmitC_Op<"if",
let hasCustomAssemblyFormat = 1;
}

def EmitC_SubscriptOp : EmitC_Op<"subscript",
[TypesMatchWith<"result type matches element type of 'array'",
"array", "result",
"::llvm::cast<ArrayType>($_self).getElementType()">]> {
let summary = "Array subscript operation";
def EmitC_SubscriptOp : EmitC_Op<"subscript", []> {
let summary = "Subscript operation";
let description = [{
With the `subscript` operation the subscript operator `[]` can be applied
to variables or arguments of array type.
to variables or arguments of array, pointer and opaque type.

Example:

```mlir
%i = index.constant 1
%j = index.constant 7
%0 = emitc.subscript %arg0[%i, %j] : <4x8xf32>, index, index
%0 = emitc.subscript %arg0[%i, %j] : !emitc.array<4x8xf32>, index, index
%1 = emitc.subscript %arg1[%i] : !emitc.ptr<i32>, index
```
}];
let arguments = (ins Arg<EmitC_ArrayType, "the reference to load from">:$array,
Variadic<IntegerIndexOrOpaqueType>:$indices);
let arguments = (ins Arg<AnyTypeOf<[
EmitC_ArrayType,
EmitC_OpaqueType,
EmitC_PointerType]>,
"the value to subscript">:$value,
Variadic<AnyType>:$indices);
let results = (outs AnyType:$result);

let builders = [
OpBuilder<(ins "Value":$array, "ValueRange":$indices), [{
build($_builder, $_state, cast<ArrayType>(array.getType()).getElementType(), array, indices);
OpBuilder<(ins "TypedValue<ArrayType>":$array, "ValueRange":$indices), [{
build($_builder, $_state, array.getType().getElementType(), array, indices);
}]>,
OpBuilder<(ins "TypedValue<PointerType>":$pointer, "Value":$index), [{
build($_builder, $_state, pointer.getType().getPointee(), pointer,
ValueRange{index});
}]>
];

let hasVerifier = 1;
let assemblyFormat = "$array `[` $indices `]` attr-dict `:` type($array) `,` type($indices)";
let assemblyFormat = "$value `[` $indices `]` attr-dict `:` functional-type(operands, results)";
}


Expand Down
15 changes: 13 additions & 2 deletions mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,14 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
}

auto arrayValue =
dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
if (!arrayValue) {
return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
}

auto subscript = rewriter.create<emitc::SubscriptOp>(
op.getLoc(), operands.getMemref(), operands.getIndices());
op.getLoc(), arrayValue, operands.getIndices());

auto noInit = emitc::OpaqueAttr::get(getContext(), "");
auto var =
Expand All @@ -81,9 +87,14 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
LogicalResult
matchAndRewrite(memref::StoreOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {
auto arrayValue =
dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
if (!arrayValue) {
return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
}

auto subscript = rewriter.create<emitc::SubscriptOp>(
op.getLoc(), operands.getMemref(), operands.getIndices());
op.getLoc(), arrayValue, operands.getIndices());
rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
operands.getValue());
return success();
Expand Down
64 changes: 59 additions & 5 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ bool mlir::emitc::isSupportedIntegerType(Type type) {
return false;
}

bool mlir::emitc::isIntegerIndexOrOpaqueType(Type type) {
return llvm::isa<IndexType, emitc::OpaqueType>(type) ||
isSupportedIntegerType(type);
}

bool mlir::emitc::isSupportedFloatType(Type type) {
if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
switch (floatType.getWidth()) {
Expand Down Expand Up @@ -781,12 +786,61 @@ LogicalResult emitc::YieldOp::verify() {
//===----------------------------------------------------------------------===//

LogicalResult emitc::SubscriptOp::verify() {
if (getIndices().size() != (size_t)getArray().getType().getRank()) {
return emitOpError() << "requires number of indices ("
<< getIndices().size()
<< ") to match the rank of the array type ("
<< getArray().getType().getRank() << ")";
// Checks for array operand.
if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(getValue().getType())) {
// Check number of indices.
if (getIndices().size() != (size_t)arrayType.getRank()) {
return emitOpError() << "on array operand requires number of indices ("
<< getIndices().size()
<< ") to match the rank of the array type ("
<< arrayType.getRank() << ")";
}
// Check types of index operands.
for (unsigned i = 0, e = getIndices().size(); i != e; ++i) {
Type type = getIndices()[i].getType();
if (!isIntegerIndexOrOpaqueType(type)) {
return emitOpError() << "on array operand requires index operand " << i
<< " to be integer-like, but got " << type;
}
}
// Check element type.
Type elementType = arrayType.getElementType();
if (elementType != getType()) {
return emitOpError() << "on array operand requires element type ("
<< elementType << ") and result type (" << getType()
<< ") to match";
}
return success();
}

// Checks for pointer operand.
if (auto pointerType =
llvm::dyn_cast<emitc::PointerType>(getValue().getType())) {
// Check number of indices.
if (getIndices().size() != 1) {
return emitOpError()
<< "on pointer operand requires one index operand, but got "
<< getIndices().size();
}
// Check types of index operand.
Type type = getIndices()[0].getType();
if (!isIntegerIndexOrOpaqueType(type)) {
return emitOpError() << "on pointer operand requires index operand to be "
"integer-like, but got "
<< type;
}
// Check pointee type.
Type pointeeType = pointerType.getPointee();
if (pointeeType != getType()) {
return emitOpError() << "on pointer operand requires pointee type ("
<< pointeeType << ") and result type (" << getType()
<< ") to match";
}
return success();
}

// The operand has opaque type, so we can't assume anything about the number
// or types of index operands.
return success();
}

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Target/Cpp/TranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1105,7 +1105,7 @@ CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
std::string out;
llvm::raw_string_ostream ss(out);
ss << getOrCreateName(op.getArray());
ss << getOrCreateName(op.getValue());
for (auto index : op.getIndices()) {
ss << "[" << getOrCreateName(index) << "]";
}
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ func.func @memref_store(%v : f32, %i: index, %j: index) {
// CHECK: %[[ALLOCA:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
%0 = memref.alloca() : memref<4x8xf32>

// CHECK: %[[SUBSCRIPT:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : <4x8xf32>
// CHECK: %[[SUBSCRIPT:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : (!emitc.array<4x8xf32>, index, index) -> f32
// CHECK: emitc.assign %[[v]] : f32 to %[[SUBSCRIPT:.*]] : f32
memref.store %v, %0[%i, %j] : memref<4x8xf32>
return
Expand All @@ -19,7 +19,7 @@ func.func @memref_load(%i: index, %j: index) -> f32 {
// CHECK: %[[ALLOCA:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
%0 = memref.alloca() : memref<4x8xf32>

// CHECK: %[[LOAD:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : <4x8xf32>
// CHECK: %[[LOAD:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : (!emitc.array<4x8xf32>, index, index) -> f32
// CHECK: %[[VAR:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
// CHECK: emitc.assign %[[LOAD]] : f32 to %[[VAR]] : f32
%1 = memref.load %0[%i, %j] : memref<4x8xf32>
Expand Down
46 changes: 43 additions & 3 deletions mlir/test/Dialect/EmitC/invalid_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -390,8 +390,48 @@ func.func @logical_or_resulterror(%arg0: i32, %arg1: i32) {

// -----

func.func @test_subscript_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg2: index) {
// expected-error @+1 {{'emitc.subscript' op requires number of indices (1) to match the rank of the array type (2)}}
%0 = emitc.subscript %arg0[%arg2] : <4x8xf32>, index
func.func @test_subscript_array_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg1: index) {
// expected-error @+1 {{'emitc.subscript' op on array operand requires number of indices (1) to match the rank of the array type (2)}}
%0 = emitc.subscript %arg0[%arg1] : (!emitc.array<4x8xf32>, index) -> f32
return
}

// -----

func.func @test_subscript_array_index_type_mismatch(%arg0: !emitc.array<4x8xf32>, %arg1: index, %arg2: f32) {
// expected-error @+1 {{'emitc.subscript' op on array operand requires index operand 1 to be integer-like, but got 'f32'}}
%0 = emitc.subscript %arg0[%arg1, %arg2] : (!emitc.array<4x8xf32>, index, f32) -> f32
return
}

// -----

func.func @test_subscript_array_type_mismatch(%arg0: !emitc.array<4x8xf32>, %arg1: index, %arg2: index) {
// expected-error @+1 {{'emitc.subscript' op on array operand requires element type ('f32') and result type ('i32') to match}}
%0 = emitc.subscript %arg0[%arg1, %arg2] : (!emitc.array<4x8xf32>, index, index) -> i32
return
}

// -----

func.func @test_subscript_ptr_indices_mismatch(%arg0: !emitc.ptr<f32>, %arg1: index) {
// expected-error @+1 {{'emitc.subscript' op on pointer operand requires one index operand, but got 2}}
%0 = emitc.subscript %arg0[%arg1, %arg1] : (!emitc.ptr<f32>, index, index) -> f32
return
}

// -----

func.func @test_subscript_ptr_index_type_mismatch(%arg0: !emitc.ptr<f32>, %arg1: f64) {
// expected-error @+1 {{'emitc.subscript' op on pointer operand requires index operand to be integer-like, but got 'f64'}}
%0 = emitc.subscript %arg0[%arg1] : (!emitc.ptr<f32>, f64) -> f32
return
}

// -----

func.func @test_subscript_ptr_type_mismatch(%arg0: !emitc.ptr<f32>, %arg1: index) {
// expected-error @+1 {{'emitc.subscript' op on pointer operand requires pointee type ('f32') and result type ('f64') to match}}
%0 = emitc.subscript %arg0[%arg1] : (!emitc.ptr<f32>, index) -> f64
return
}
7 changes: 7 additions & 0 deletions mlir/test/Dialect/EmitC/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,13 @@ func.func @test_for_not_index_induction(%arg0 : i16, %arg1 : i16, %arg2 : i16) {
return
}

func.func @test_subscript(%arg0 : !emitc.array<2x3xf32>, %arg1 : !emitc.ptr<i32>, %arg2 : !emitc.opaque<"std::map<char, int>">, %idx0 : index, %idx1 : i32, %idx2 : !emitc.opaque<"char">) {
%0 = emitc.subscript %arg0[%idx0, %idx1] : (!emitc.array<2x3xf32>, index, i32) -> f32
%1 = emitc.subscript %arg1[%idx0] : (!emitc.ptr<i32>, index) -> i32
%2 = emitc.subscript %arg2[%idx2] : (!emitc.opaque<"std::map<char, int>">, !emitc.opaque<"char">) -> !emitc.opaque<"int">
return
}

emitc.verbatim "#ifdef __cplusplus"
emitc.verbatim "extern \"C\" {"
emitc.verbatim "#endif // __cplusplus"
Expand Down
32 changes: 26 additions & 6 deletions mlir/test/Target/Cpp/subscript.mlir
Original file line number Diff line number Diff line change
@@ -1,24 +1,44 @@
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s

func.func @load_store(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) {
%0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>, index, index
%1 = emitc.subscript %arg1[%arg2, %arg3] : <3x5xf32>, index, index
func.func @load_store_array(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) {
%0 = emitc.subscript %arg0[%arg2, %arg3] : (!emitc.array<4x8xf32>, index, index) -> f32
%1 = emitc.subscript %arg1[%arg2, %arg3] : (!emitc.array<3x5xf32>, index, index) -> f32
emitc.assign %0 : f32 to %1 : f32
return
}
// CHECK: void load_store(float [[ARR1:[^ ]*]][4][8], float [[ARR2:[^ ]*]][3][5],
// CHECK: void load_store_array(float [[ARR1:[^ ]*]][4][8], float [[ARR2:[^ ]*]][3][5],
// CHECK-SAME: size_t [[I:[^ ]*]], size_t [[J:[^ ]*]])
// CHECK-NEXT: [[ARR2]][[[I]]][[[J]]] = [[ARR1]][[[I]]][[[J]]];

func.func @load_store_pointer(%arg0: !emitc.ptr<f32>, %arg1: !emitc.ptr<f32>, %arg2: index, %arg3: index) {
%0 = emitc.subscript %arg0[%arg2] : (!emitc.ptr<f32>, index) -> f32
%1 = emitc.subscript %arg1[%arg3] : (!emitc.ptr<f32>, index) -> f32
emitc.assign %0 : f32 to %1 : f32
return
}
// CHECK: void load_store_pointer(float* [[PTR1:[^ ]*]], float* [[PTR2:[^ ]*]],
// CHECK-SAME: size_t [[I:[^ ]*]], size_t [[J:[^ ]*]])
// CHECK-NEXT: [[PTR2]][[[J]]] = [[PTR1]][[[I]]];

func.func @load_store_opaque(%arg0: !emitc.opaque<"std::map<char, int>">, %arg1: !emitc.opaque<"std::map<char, int>">, %arg2: !emitc.opaque<"char">, %arg3: !emitc.opaque<"char">) {
%0 = emitc.subscript %arg0[%arg2] : (!emitc.opaque<"std::map<char, int>">, !emitc.opaque<"char">) -> !emitc.opaque<"int">
%1 = emitc.subscript %arg1[%arg3] : (!emitc.opaque<"std::map<char, int>">, !emitc.opaque<"char">) -> !emitc.opaque<"int">
emitc.assign %0 : !emitc.opaque<"int"> to %1 : !emitc.opaque<"int">
return
}
// CHECK: void load_store_opaque(std::map<char, int> [[MAP1:[^ ]*]], std::map<char, int> [[MAP2:[^ ]*]],
// CHECK-SAME: char [[I:[^ ]*]], char [[J:[^ ]*]])
// CHECK-NEXT: [[MAP2]][[[J]]] = [[MAP1]][[[I]]];

emitc.func @func1(%arg0 : f32) {
emitc.return
}

emitc.func @call_arg(%arg0: !emitc.array<4x8xf32>, %i: i32, %j: i16,
%k: i8) {
%0 = emitc.subscript %arg0[%i, %j] : <4x8xf32>, i32, i16
%1 = emitc.subscript %arg0[%j, %k] : <4x8xf32>, i16, i8
%0 = emitc.subscript %arg0[%i, %j] : (!emitc.array<4x8xf32>, i32, i16) -> f32
%1 = emitc.subscript %arg0[%j, %k] : (!emitc.array<4x8xf32>, i16, i8) -> f32

emitc.call @func1 (%0) : (f32) -> ()
emitc.call_opaque "func2" (%1) : (f32) -> ()
Expand Down