Skip to content

Commit 3bc50bf

Browse files
committed
[MLIR] EmitC: Add subscript operator
Introduces a SubscriptOp that allows to write IR like ``` func.func @load_store(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) { %0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32> %1 = emitc.subscript %arg1[%arg2, %arg3] : <3x5xf32> emitc.assign %0 : f32 to %1 : f32 return } ``` which gets translated into the C++ code ``` v1[v2][v3] = v0[v1][v2]; ``` To make this happen, this - adds the SubscriptOp - allows the subscript op as rhs of emitc.assign - updates the emitter to print SubscriptOps The emitter prints emitc.subscript in a delayed fashing to allow it being used as lvalue. I.e. while processing ``` %0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32> ``` it will not emit any text, but record in the `valueMapper` that the name for `%0` is `v0[v1][v2]`, see `CppEmitter::getSubscriptName`. Only when that result is then used (here in `emitc.assign`), that name is inserted into the text.
1 parent 818af71 commit 3bc50bf

File tree

5 files changed

+103
-9
lines changed

5 files changed

+103
-9
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,4 +1125,36 @@ def EmitC_IfOp : EmitC_Op<"if",
11251125
let hasCustomAssemblyFormat = 1;
11261126
}
11271127

1128+
def EmitC_SubscriptOp : EmitC_Op<"subscript",
1129+
[TypesMatchWith<"result type matches element type of 'array'",
1130+
"array", "result",
1131+
"::llvm::cast<ArrayType>($_self).getElementType()">]> {
1132+
let summary = "Array subscript operation";
1133+
let description = [{
1134+
With the `subscript` operation the subscript operator `[]` can be applied
1135+
to variables or arguments of array type.
1136+
1137+
Example:
1138+
1139+
```mlir
1140+
%i = index.constant 1
1141+
%j = index.constant 7
1142+
%0 = emitc.subscript %arg0[%i][%j] : (!emitc.array<4x8xf32>) -> f32
1143+
```
1144+
}];
1145+
let arguments = (ins Arg<EmitC_ArrayType, "the reference to load from">:$array,
1146+
Variadic<Index>:$indices);
1147+
let results = (outs AnyType:$result);
1148+
1149+
let builders = [
1150+
OpBuilder<(ins "Value":$array, "ValueRange":$indices), [{
1151+
build($_builder, $_state, cast<ArrayType>(array.getType()).getElementType(), array, indices);
1152+
}]>
1153+
];
1154+
1155+
let hasVerifier = 1;
1156+
let assemblyFormat = "$array `[` $indices `]` attr-dict `:` type($array)";
1157+
}
1158+
1159+
11281160
#endif // MLIR_DIALECT_EMITC_IR_EMITC

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,10 @@ LogicalResult ApplyOp::verify() {
132132
LogicalResult emitc::AssignOp::verify() {
133133
Value variable = getVar();
134134
Operation *variableDef = variable.getDefiningOp();
135-
if (!variableDef || !llvm::isa<emitc::VariableOp>(variableDef))
135+
if (!variableDef ||
136+
!llvm::isa<emitc::VariableOp, emitc::SubscriptOp>(variableDef))
136137
return emitOpError() << "requires first operand (" << variable
137-
<< ") to be a Variable";
138+
<< ") to be a Variable or subscript";
138139

139140
Value value = getValue();
140141
if (variable.getType() != value.getType())
@@ -746,6 +747,20 @@ LogicalResult emitc::YieldOp::verify() {
746747
return success();
747748
}
748749

750+
//===----------------------------------------------------------------------===//
751+
// SubscriptOp
752+
//===----------------------------------------------------------------------===//
753+
754+
LogicalResult emitc::SubscriptOp::verify() {
755+
if (getIndices().size() != (size_t)getArray().getType().getRank()) {
756+
return emitOpError() << "requires number of indices ("
757+
<< getIndices().size()
758+
<< ") to match the rank of the array type ("
759+
<< getArray().getType().getRank() << ")";
760+
}
761+
return success();
762+
}
763+
749764
//===----------------------------------------------------------------------===//
750765
// TableGen'd op method definitions
751766
//===----------------------------------------------------------------------===//

mlir/lib/Target/Cpp/TranslateToCpp.cpp

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,9 @@ struct CppEmitter {
171171
/// Return the existing or a new name for a Value.
172172
StringRef getOrCreateName(Value val);
173173

174+
// Returns the textual representation of a subscript operation.
175+
std::string getSubscriptName(emitc::SubscriptOp op);
176+
174177
/// Return the existing or a new label of a Block.
175178
StringRef getOrCreateName(Block &block);
176179

@@ -340,15 +343,21 @@ static LogicalResult printOperation(CppEmitter &emitter,
340343

341344
static LogicalResult printOperation(CppEmitter &emitter,
342345
emitc::AssignOp assignOp) {
343-
auto variableOp = cast<emitc::VariableOp>(assignOp.getVar().getDefiningOp());
344-
OpResult result = variableOp->getResult(0);
346+
OpResult result = assignOp.getVar().getDefiningOp()->getResult(0);
345347

346348
if (failed(emitter.emitVariableAssignment(result)))
347349
return failure();
348350

349351
return emitter.emitOperand(assignOp.getValue());
350352
}
351353

354+
static LogicalResult printOperation(CppEmitter &emitter,
355+
emitc::SubscriptOp subscriptOp) {
356+
// Add name to cache so that `hasValueInScope` works.
357+
emitter.getOrCreateName(subscriptOp.getResult());
358+
return success();
359+
}
360+
352361
static LogicalResult printBinaryOperation(CppEmitter &emitter,
353362
Operation *operation,
354363
StringRef binaryOperator) {
@@ -1067,12 +1076,28 @@ CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
10671076
labelInScopeCount.push(0);
10681077
}
10691078

1079+
std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
1080+
std::string out;
1081+
llvm::raw_string_ostream ss(out);
1082+
ss << getOrCreateName(op.getArray());
1083+
for (auto index : op.getIndices()) {
1084+
ss << "[" << getOrCreateName(index) << "]";
1085+
}
1086+
return out;
1087+
}
1088+
10701089
/// Return the existing or a new name for a Value.
10711090
StringRef CppEmitter::getOrCreateName(Value val) {
10721091
if (auto literal = dyn_cast_if_present<emitc::LiteralOp>(val.getDefiningOp()))
10731092
return literal.getValue();
1074-
if (!valueMapper.count(val))
1075-
valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
1093+
if (!valueMapper.count(val)) {
1094+
if (auto subscript =
1095+
dyn_cast_if_present<emitc::SubscriptOp>(val.getDefiningOp())) {
1096+
valueMapper.insert(val, getSubscriptName(subscript));
1097+
} else {
1098+
valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
1099+
}
1100+
}
10761101
return *valueMapper.begin(val);
10771102
}
10781103

@@ -1312,6 +1337,8 @@ LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
13121337

13131338
LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
13141339
bool trailingSemicolon) {
1340+
if (isa<emitc::SubscriptOp>(result.getDefiningOp()))
1341+
return success();
13151342
if (hasValueInScope(result)) {
13161343
return result.getDefiningOp()->emitError(
13171344
"result variable for the operation already declared");
@@ -1387,8 +1414,8 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
13871414
emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp, emitc::IfOp,
13881415
emitc::IncludeOp, emitc::LogicalAndOp, emitc::LogicalNotOp,
13891416
emitc::LogicalOrOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp,
1390-
emitc::SubOp, emitc::UnaryMinusOp, emitc::UnaryPlusOp,
1391-
emitc::VariableOp, emitc::VerbatimOp>(
1417+
emitc::SubOp, emitc::SubscriptOp, emitc::UnaryMinusOp,
1418+
emitc::UnaryPlusOp, emitc::VariableOp, emitc::VerbatimOp>(
13921419
[&](auto op) { return printOperation(*this, op); })
13931420
// Func ops.
13941421
.Case<func::CallOp, func::FuncOp, func::ReturnOp>(
@@ -1401,7 +1428,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
14011428
if (failed(status))
14021429
return failure();
14031430

1404-
if (isa<emitc::LiteralOp>(op))
1431+
if (isa<emitc::LiteralOp, emitc::SubscriptOp>(op))
14051432
return success();
14061433

14071434
if (getEmittedExpression() ||

mlir/test/Dialect/EmitC/invalid_ops.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,3 +387,11 @@ func.func @logical_or_resulterror(%arg0: i32, %arg1: i32) {
387387
%0 = "emitc.logical_or"(%arg0, %arg1) : (i32, i32) -> i32
388388
return
389389
}
390+
391+
// -----
392+
393+
func.func @test_subscript_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg2: index) {
394+
// expected-error @+1 {{'emitc.subscript' op requires number of indices (1) to match the rank of the array type (2)}}
395+
%0 = emitc.subscript %arg0[%arg2] : <4x8xf32>
396+
return
397+
}

mlir/test/Target/Cpp/subscript.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
2+
// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s
3+
4+
func.func @load_store(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) {
5+
%0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>
6+
%1 = emitc.subscript %arg1[%arg2, %arg3] : <3x5xf32>
7+
emitc.assign %0 : f32 to %1 : f32
8+
return
9+
}
10+
// CHECK: void load_store(float [[ARR1:[^ ]*]][4][8], float [[ARR2:[^ ]*]][3][5],
11+
// CHECK-SAME: size_t [[I:[^ ]*]], size_t [[J:[^ ]*]])
12+
// CHECK-NEXT: [[ARR2]][[[I]]][[[J]]] = [[ARR1]][[[I]]][[[J]]];

0 commit comments

Comments
 (0)