Skip to content

[MLIR] EmitC: Add subscript operator #84783

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -1155,4 +1155,36 @@ def EmitC_IfOp : EmitC_Op<"if",
let hasCustomAssemblyFormat = 1;
}

def EmitC_SubscriptOp : EmitC_Op<"subscript",
[TypesMatchWith<"result type matches element type of 'array'",
"array", "result",
"::llvm::cast<ArrayType>($_self).getElementType()">]> {
let summary = "Array subscript operation";
let description = [{
With the `subscript` operation the subscript operator `[]` can be applied
to variables or arguments of array type.

Example:

```mlir
%i = index.constant 1
%j = index.constant 7
%0 = emitc.subscript %arg0[%i, %j] : <4x8xf32>, index, index
```
}];
let arguments = (ins Arg<EmitC_ArrayType, "the reference to load from">:$array,
Variadic<IntegerIndexOrOpaqueType>:$indices);
let results = (outs AnyType:$result);

let builders = [
OpBuilder<(ins "Value":$array, "ValueRange":$indices), [{
build($_builder, $_state, cast<ArrayType>(array.getType()).getElementType(), array, indices);
}]>
];

let hasVerifier = 1;
let assemblyFormat = "$array `[` $indices `]` attr-dict `:` type($array) `,` type($indices)";
}


#endif // MLIR_DIALECT_EMITC_IR_EMITC
19 changes: 17 additions & 2 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,10 @@ LogicalResult ApplyOp::verify() {
LogicalResult emitc::AssignOp::verify() {
Value variable = getVar();
Operation *variableDef = variable.getDefiningOp();
if (!variableDef || !llvm::isa<emitc::VariableOp>(variableDef))
if (!variableDef ||
!llvm::isa<emitc::VariableOp, emitc::SubscriptOp>(variableDef))
return emitOpError() << "requires first operand (" << variable
<< ") to be a Variable";
<< ") to be a Variable or subscript";

Value value = getValue();
if (variable.getType() != value.getType())
Expand Down Expand Up @@ -746,6 +747,20 @@ LogicalResult emitc::YieldOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// SubscriptOp
//===----------------------------------------------------------------------===//

LogicalResult emitc::SubscriptOp::verify() {
if (getIndices().size() != (size_t)getArray().getType().getRank()) {
return emitOpError() << "requires number of indices ("
<< getIndices().size()
<< ") to match the rank of the array type ("
<< getArray().getType().getRank() << ")";
}
return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
Expand Down
39 changes: 33 additions & 6 deletions mlir/lib/Target/Cpp/TranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,9 @@ struct CppEmitter {
/// Return the existing or a new name for a Value.
StringRef getOrCreateName(Value val);

// Returns the textual representation of a subscript operation.
std::string getSubscriptName(emitc::SubscriptOp op);

/// Return the existing or a new label of a Block.
StringRef getOrCreateName(Block &block);

Expand Down Expand Up @@ -341,15 +344,21 @@ static LogicalResult printOperation(CppEmitter &emitter,

static LogicalResult printOperation(CppEmitter &emitter,
emitc::AssignOp assignOp) {
auto variableOp = cast<emitc::VariableOp>(assignOp.getVar().getDefiningOp());
OpResult result = variableOp->getResult(0);
OpResult result = assignOp.getVar().getDefiningOp()->getResult(0);

if (failed(emitter.emitVariableAssignment(result)))
return failure();

return emitter.emitOperand(assignOp.getValue());
}

static LogicalResult printOperation(CppEmitter &emitter,
emitc::SubscriptOp subscriptOp) {
// Add name to cache so that `hasValueInScope` works.
emitter.getOrCreateName(subscriptOp.getResult());
return success();
}

static LogicalResult printBinaryOperation(CppEmitter &emitter,
Operation *operation,
StringRef binaryOperator) {
Expand Down Expand Up @@ -1091,12 +1100,28 @@ CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
labelInScopeCount.push(0);
}

std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
std::string out;
llvm::raw_string_ostream ss(out);
ss << getOrCreateName(op.getArray());
for (auto index : op.getIndices()) {
ss << "[" << getOrCreateName(index) << "]";
}
return out;
}

/// Return the existing or a new name for a Value.
StringRef CppEmitter::getOrCreateName(Value val) {
if (auto literal = dyn_cast_if_present<emitc::LiteralOp>(val.getDefiningOp()))
return literal.getValue();
if (!valueMapper.count(val))
valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
if (!valueMapper.count(val)) {
if (auto subscript =
dyn_cast_if_present<emitc::SubscriptOp>(val.getDefiningOp())) {
valueMapper.insert(val, getSubscriptName(subscript));
} else {
valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
}
}
return *valueMapper.begin(val);
}

Expand Down Expand Up @@ -1336,6 +1361,8 @@ LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {

LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
bool trailingSemicolon) {
if (isa<emitc::SubscriptOp>(result.getDefiningOp()))
return success();
if (hasValueInScope(result)) {
return result.getDefiningOp()->emitError(
"result variable for the operation already declared");
Expand Down Expand Up @@ -1411,7 +1438,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp,
emitc::IfOp, emitc::IncludeOp, emitc::LogicalAndOp,
emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp,
emitc::RemOp, emitc::ReturnOp, emitc::SubOp,
emitc::RemOp, emitc::ReturnOp, emitc::SubOp, emitc::SubscriptOp,
emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
emitc::VerbatimOp>(
[&](auto op) { return printOperation(*this, op); })
Expand All @@ -1426,7 +1453,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
if (failed(status))
return failure();

if (isa<emitc::LiteralOp>(op))
if (isa<emitc::LiteralOp, emitc::SubscriptOp>(op))
return success();

if (getEmittedExpression() ||
Expand Down
10 changes: 9 additions & 1 deletion mlir/test/Dialect/EmitC/invalid_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ func.func @test_misplaced_yield() {
// -----

func.func @test_assign_to_non_variable(%arg1: f32, %arg2: f32) {
// expected-error @+1 {{'emitc.assign' op requires first operand (<block argument> of type 'f32' at index: 1) to be a Variable}}
// expected-error @+1 {{'emitc.assign' op requires first operand (<block argument> of type 'f32' at index: 1) to be a Variable or subscript}}
emitc.assign %arg1 : f32 to %arg2 : f32
return
}
Expand Down Expand Up @@ -387,3 +387,11 @@ func.func @logical_or_resulterror(%arg0: i32, %arg1: i32) {
%0 = "emitc.logical_or"(%arg0, %arg1) : (i32, i32) -> i32
return
}

// -----

func.func @test_subscript_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg2: index) {
// expected-error @+1 {{'emitc.subscript' op requires number of indices (1) to match the rank of the array type (2)}}
%0 = emitc.subscript %arg0[%arg2] : <4x8xf32>, index
return
}
32 changes: 32 additions & 0 deletions mlir/test/Target/Cpp/subscript.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s

func.func @load_store(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) {
%0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>, index, index
%1 = emitc.subscript %arg1[%arg2, %arg3] : <3x5xf32>, index, index
emitc.assign %0 : f32 to %1 : f32
return
}
// CHECK: void load_store(float [[ARR1:[^ ]*]][4][8], float [[ARR2:[^ ]*]][3][5],
// CHECK-SAME: size_t [[I:[^ ]*]], size_t [[J:[^ ]*]])
// CHECK-NEXT: [[ARR2]][[[I]]][[[J]]] = [[ARR1]][[[I]]][[[J]]];

emitc.func @func1(%arg0 : f32) {
emitc.return
}

emitc.func @call_arg(%arg0: !emitc.array<4x8xf32>, %i: i32, %j: i16,
%k: i8) {
%0 = emitc.subscript %arg0[%i, %j] : <4x8xf32>, i32, i16
%1 = emitc.subscript %arg0[%j, %k] : <4x8xf32>, i16, i8

emitc.call @func1 (%0) : (f32) -> ()
emitc.call_opaque "func2" (%1) : (f32) -> ()
emitc.call_opaque "func3" (%0, %1) { args = [1 : index, 0 : index] } : (f32, f32) -> ()
emitc.return
}
// CHECK: void call_arg(float [[ARR1:[^ ]*]][4][8], int32_t [[I:[^ ]*]],
// CHECK-SAME: int16_t [[J:[^ ]*]], int8_t [[K:[^ ]*]])
// CHECK-NEXT: func1([[ARR1]][[[I]]][[[J]]]);
// CHECK-NEXT: func2([[ARR1]][[[J]]][[[K]]]);
// CHECK-NEXT: func3([[ARR1]][[[J]]][[[K]]], [[ARR1]][[[I]]][[[J]]]);