Skip to content

Commit 01a31ce

Browse files
authored
[MLIR] EmitC: Add subscript operator (#84783)
Introduces a SubscriptOp that allows to write IR like ``` func.func @load_store(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) { %0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>, index, index %1 = emitc.subscript %arg1[%arg2, %arg3] : <3x5xf32>, index, index emitc.assign %0 : f32 to %1 : f32 return } ``` which gets translated into the C++ code ``` v1[v2][v3] = v0[v1][v2]; ``` To make this happen, this - adds the SubscriptOp - allows the subscript op as rhs of emitc.assign - updates the emitter to print SubscriptOps The emitter prints emitc.subscript in a delayed fashing to allow it being used as lvalue. I.e. while processing ``` %0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>, index, index ``` it will not emit any text, but record in the `valueMapper` that the name for `%0` is `v0[v1][v2]`, see `CppEmitter::getSubscriptName`. Only when that result is then used (here in `emitc.assign`), that name is inserted into the text.
1 parent dbb2fd5 commit 01a31ce

File tree

5 files changed

+123
-9
lines changed

5 files changed

+123
-9
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1155,4 +1155,36 @@ def EmitC_IfOp : EmitC_Op<"if",
11551155
let hasCustomAssemblyFormat = 1;
11561156
}
11571157

1158+
def EmitC_SubscriptOp : EmitC_Op<"subscript",
1159+
[TypesMatchWith<"result type matches element type of 'array'",
1160+
"array", "result",
1161+
"::llvm::cast<ArrayType>($_self).getElementType()">]> {
1162+
let summary = "Array subscript operation";
1163+
let description = [{
1164+
With the `subscript` operation the subscript operator `[]` can be applied
1165+
to variables or arguments of array type.
1166+
1167+
Example:
1168+
1169+
```mlir
1170+
%i = index.constant 1
1171+
%j = index.constant 7
1172+
%0 = emitc.subscript %arg0[%i, %j] : <4x8xf32>, index, index
1173+
```
1174+
}];
1175+
let arguments = (ins Arg<EmitC_ArrayType, "the reference to load from">:$array,
1176+
Variadic<IntegerIndexOrOpaqueType>:$indices);
1177+
let results = (outs AnyType:$result);
1178+
1179+
let builders = [
1180+
OpBuilder<(ins "Value":$array, "ValueRange":$indices), [{
1181+
build($_builder, $_state, cast<ArrayType>(array.getType()).getElementType(), array, indices);
1182+
}]>
1183+
];
1184+
1185+
let hasVerifier = 1;
1186+
let assemblyFormat = "$array `[` $indices `]` attr-dict `:` type($array) `,` type($indices)";
1187+
}
1188+
1189+
11581190
#endif // MLIR_DIALECT_EMITC_IR_EMITC

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,10 @@ LogicalResult ApplyOp::verify() {
132132
LogicalResult emitc::AssignOp::verify() {
133133
Value variable = getVar();
134134
Operation *variableDef = variable.getDefiningOp();
135-
if (!variableDef || !llvm::isa<emitc::VariableOp>(variableDef))
135+
if (!variableDef ||
136+
!llvm::isa<emitc::VariableOp, emitc::SubscriptOp>(variableDef))
136137
return emitOpError() << "requires first operand (" << variable
137-
<< ") to be a Variable";
138+
<< ") to be a Variable or subscript";
138139

139140
Value value = getValue();
140141
if (variable.getType() != value.getType())
@@ -746,6 +747,20 @@ LogicalResult emitc::YieldOp::verify() {
746747
return success();
747748
}
748749

750+
//===----------------------------------------------------------------------===//
751+
// SubscriptOp
752+
//===----------------------------------------------------------------------===//
753+
754+
LogicalResult emitc::SubscriptOp::verify() {
755+
if (getIndices().size() != (size_t)getArray().getType().getRank()) {
756+
return emitOpError() << "requires number of indices ("
757+
<< getIndices().size()
758+
<< ") to match the rank of the array type ("
759+
<< getArray().getType().getRank() << ")";
760+
}
761+
return success();
762+
}
763+
749764
//===----------------------------------------------------------------------===//
750765
// TableGen'd op method definitions
751766
//===----------------------------------------------------------------------===//

mlir/lib/Target/Cpp/TranslateToCpp.cpp

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@ struct CppEmitter {
174174
/// Return the existing or a new name for a Value.
175175
StringRef getOrCreateName(Value val);
176176

177+
// Returns the textual representation of a subscript operation.
178+
std::string getSubscriptName(emitc::SubscriptOp op);
179+
177180
/// Return the existing or a new label of a Block.
178181
StringRef getOrCreateName(Block &block);
179182

@@ -343,15 +346,21 @@ static LogicalResult printOperation(CppEmitter &emitter,
343346

344347
static LogicalResult printOperation(CppEmitter &emitter,
345348
emitc::AssignOp assignOp) {
346-
auto variableOp = cast<emitc::VariableOp>(assignOp.getVar().getDefiningOp());
347-
OpResult result = variableOp->getResult(0);
349+
OpResult result = assignOp.getVar().getDefiningOp()->getResult(0);
348350

349351
if (failed(emitter.emitVariableAssignment(result)))
350352
return failure();
351353

352354
return emitter.emitOperand(assignOp.getValue());
353355
}
354356

357+
static LogicalResult printOperation(CppEmitter &emitter,
358+
emitc::SubscriptOp subscriptOp) {
359+
// Add name to cache so that `hasValueInScope` works.
360+
emitter.getOrCreateName(subscriptOp.getResult());
361+
return success();
362+
}
363+
355364
static LogicalResult printBinaryOperation(CppEmitter &emitter,
356365
Operation *operation,
357366
StringRef binaryOperator) {
@@ -1093,12 +1102,28 @@ CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
10931102
labelInScopeCount.push(0);
10941103
}
10951104

1105+
std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
1106+
std::string out;
1107+
llvm::raw_string_ostream ss(out);
1108+
ss << getOrCreateName(op.getArray());
1109+
for (auto index : op.getIndices()) {
1110+
ss << "[" << getOrCreateName(index) << "]";
1111+
}
1112+
return out;
1113+
}
1114+
10961115
/// Return the existing or a new name for a Value.
10971116
StringRef CppEmitter::getOrCreateName(Value val) {
10981117
if (auto literal = dyn_cast_if_present<emitc::LiteralOp>(val.getDefiningOp()))
10991118
return literal.getValue();
1100-
if (!valueMapper.count(val))
1101-
valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
1119+
if (!valueMapper.count(val)) {
1120+
if (auto subscript =
1121+
dyn_cast_if_present<emitc::SubscriptOp>(val.getDefiningOp())) {
1122+
valueMapper.insert(val, getSubscriptName(subscript));
1123+
} else {
1124+
valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
1125+
}
1126+
}
11021127
return *valueMapper.begin(val);
11031128
}
11041129

@@ -1338,6 +1363,8 @@ LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
13381363

13391364
LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
13401365
bool trailingSemicolon) {
1366+
if (isa<emitc::SubscriptOp>(result.getDefiningOp()))
1367+
return success();
13411368
if (hasValueInScope(result)) {
13421369
return result.getDefiningOp()->emitError(
13431370
"result variable for the operation already declared");
@@ -1413,7 +1440,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
14131440
emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp,
14141441
emitc::IfOp, emitc::IncludeOp, emitc::LogicalAndOp,
14151442
emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp,
1416-
emitc::RemOp, emitc::ReturnOp, emitc::SubOp,
1443+
emitc::RemOp, emitc::ReturnOp, emitc::SubOp, emitc::SubscriptOp,
14171444
emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
14181445
emitc::VerbatimOp>(
14191446
[&](auto op) { return printOperation(*this, op); })
@@ -1428,7 +1455,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
14281455
if (failed(status))
14291456
return failure();
14301457

1431-
if (isa<emitc::LiteralOp>(op))
1458+
if (isa<emitc::LiteralOp, emitc::SubscriptOp>(op))
14321459
return success();
14331460

14341461
if (getEmittedExpression() ||

mlir/test/Dialect/EmitC/invalid_ops.mlir

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ func.func @test_misplaced_yield() {
235235
// -----
236236

237237
func.func @test_assign_to_non_variable(%arg1: f32, %arg2: f32) {
238-
// expected-error @+1 {{'emitc.assign' op requires first operand (<block argument> of type 'f32' at index: 1) to be a Variable}}
238+
// expected-error @+1 {{'emitc.assign' op requires first operand (<block argument> of type 'f32' at index: 1) to be a Variable or subscript}}
239239
emitc.assign %arg1 : f32 to %arg2 : f32
240240
return
241241
}
@@ -387,3 +387,11 @@ func.func @logical_or_resulterror(%arg0: i32, %arg1: i32) {
387387
%0 = "emitc.logical_or"(%arg0, %arg1) : (i32, i32) -> i32
388388
return
389389
}
390+
391+
// -----
392+
393+
func.func @test_subscript_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg2: index) {
394+
// expected-error @+1 {{'emitc.subscript' op requires number of indices (1) to match the rank of the array type (2)}}
395+
%0 = emitc.subscript %arg0[%arg2] : <4x8xf32>, index
396+
return
397+
}

mlir/test/Target/Cpp/subscript.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
2+
// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s
3+
4+
func.func @load_store(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) {
5+
%0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>, index, index
6+
%1 = emitc.subscript %arg1[%arg2, %arg3] : <3x5xf32>, index, index
7+
emitc.assign %0 : f32 to %1 : f32
8+
return
9+
}
10+
// CHECK: void load_store(float [[ARR1:[^ ]*]][4][8], float [[ARR2:[^ ]*]][3][5],
11+
// CHECK-SAME: size_t [[I:[^ ]*]], size_t [[J:[^ ]*]])
12+
// CHECK-NEXT: [[ARR2]][[[I]]][[[J]]] = [[ARR1]][[[I]]][[[J]]];
13+
14+
emitc.func @func1(%arg0 : f32) {
15+
emitc.return
16+
}
17+
18+
emitc.func @call_arg(%arg0: !emitc.array<4x8xf32>, %i: i32, %j: i16,
19+
%k: i8) {
20+
%0 = emitc.subscript %arg0[%i, %j] : <4x8xf32>, i32, i16
21+
%1 = emitc.subscript %arg0[%j, %k] : <4x8xf32>, i16, i8
22+
23+
emitc.call @func1 (%0) : (f32) -> ()
24+
emitc.call_opaque "func2" (%1) : (f32) -> ()
25+
emitc.call_opaque "func3" (%0, %1) { args = [1 : index, 0 : index] } : (f32, f32) -> ()
26+
emitc.return
27+
}
28+
// CHECK: void call_arg(float [[ARR1:[^ ]*]][4][8], int32_t [[I:[^ ]*]],
29+
// CHECK-SAME: int16_t [[J:[^ ]*]], int8_t [[K:[^ ]*]])
30+
// CHECK-NEXT: func1([[ARR1]][[[I]]][[[J]]]);
31+
// CHECK-NEXT: func2([[ARR1]][[[J]]][[[K]]]);
32+
// CHECK-NEXT: func3([[ARR1]][[[J]]][[[K]]], [[ARR1]][[[I]]][[[J]]]);

0 commit comments

Comments
 (0)