Skip to content

Commit c938fb8

Browse files
authored
[MLIR] EmitC: Add subscript operator
Reviewers: TinaAMD Reviewed By: TinaAMD Pull Request: #118
1 parent d3c90a4 commit c938fb8

File tree

6 files changed

+110
-8
lines changed

6 files changed

+110
-8
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,4 +565,36 @@ def EmitC_IfOp : EmitC_Op<"if",
565565
let hasCustomAssemblyFormat = 1;
566566
}
567567

568+
def EmitC_SubscriptOp : EmitC_Op<"subscript",
569+
[TypesMatchWith<"result type matches element type of 'array'",
570+
"array", "result",
571+
"::llvm::cast<ArrayType>($_self).getElementType()">]> {
572+
let summary = "Array subscript operation";
573+
let description = [{
574+
With the `subscript` operation the subscript operator `[]` can be applied
575+
to variables or arguments of array type.
576+
577+
Example:
578+
579+
```mlir
580+
%i = index.constant 1
581+
%j = index.constant 7
582+
%0 = emitc.subscript %arg0[%i][%j] : (!emitc.array<4x8xf32>) -> f32
583+
```
584+
}];
585+
let arguments = (ins Arg<EmitC_ArrayType, "the reference to load from">:$array,
586+
Variadic<Index>:$indices);
587+
let results = (outs AnyType:$result);
588+
589+
let builders = [
590+
OpBuilder<(ins "Value":$array, "ValueRange":$indices), [{
591+
build($_builder, $_state, cast<ArrayType>(array.getType()).getElementType(), array, indices);
592+
}]>
593+
];
594+
595+
let hasVerifier = 1;
596+
let assemblyFormat = "$array `[` $indices `]` attr-dict `:` type($array)";
597+
}
598+
599+
568600
#endif // MLIR_DIALECT_EMITC_IR_EMITC

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,10 @@ LogicalResult ApplyOp::verify() {
102102
LogicalResult emitc::AssignOp::verify() {
103103
Value variable = getVar();
104104
Operation *variableDef = variable.getDefiningOp();
105-
if (!variableDef || !llvm::isa<emitc::VariableOp>(variableDef))
105+
if (!variableDef ||
106+
!llvm::isa<emitc::VariableOp, emitc::SubscriptOp>(variableDef))
106107
return emitOpError() << "requires first operand (" << variable
107-
<< ") to be a Variable";
108+
<< ") to be a Variable or subscript";
108109

109110
Value value = getValue();
110111
if (variable.getType() != value.getType())
@@ -530,6 +531,20 @@ LogicalResult emitc::VariableOp::verify() {
530531
return success();
531532
}
532533

534+
//===----------------------------------------------------------------------===//
535+
// SubscriptOp
536+
//===----------------------------------------------------------------------===//
537+
538+
LogicalResult emitc::SubscriptOp::verify() {
539+
if (getIndices().size() != (size_t)getArray().getType().getRank()) {
540+
return emitOpError() << "requires number of indices ("
541+
<< getIndices().size()
542+
<< ") to match the rank of the array type ("
543+
<< getArray().getType().getRank() << ")";
544+
}
545+
return success();
546+
}
547+
533548
//===----------------------------------------------------------------------===//
534549
// TableGen'd op method definitions
535550
//===----------------------------------------------------------------------===//

mlir/lib/Target/Cpp/TranslateToCpp.cpp

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,9 @@ struct CppEmitter {
122122
/// Return the existing or a new name for a Value.
123123
StringRef getOrCreateName(Value val);
124124

125+
// Returns the textual representation of a subscript operation.
126+
std::string getSubscriptName(emitc::SubscriptOp op);
127+
125128
/// Return the existing or a new label of a Block.
126129
StringRef getOrCreateName(Block &block);
127130

@@ -251,8 +254,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
251254

252255
static LogicalResult printOperation(CppEmitter &emitter,
253256
emitc::AssignOp assignOp) {
254-
auto variableOp = cast<emitc::VariableOp>(assignOp.getVar().getDefiningOp());
255-
OpResult result = variableOp->getResult(0);
257+
OpResult result = assignOp.getVar().getDefiningOp()->getResult(0);
256258

257259
if (failed(emitter.emitVariableAssignment(result)))
258260
return failure();
@@ -262,6 +264,13 @@ static LogicalResult printOperation(CppEmitter &emitter,
262264
return success();
263265
}
264266

267+
static LogicalResult printOperation(CppEmitter &emitter,
268+
emitc::SubscriptOp subscriptOp) {
269+
// Add name to cache so that `hasValueInScope` works.
270+
emitter.getOrCreateName(subscriptOp.getResult());
271+
return success();
272+
}
273+
265274
static LogicalResult printBinaryOperation(CppEmitter &emitter,
266275
Operation *operation,
267276
StringRef binaryOperator) {
@@ -706,12 +715,28 @@ CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
706715
labelInScopeCount.push(0);
707716
}
708717

718+
std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
719+
std::string out;
720+
llvm::raw_string_ostream ss(out);
721+
ss << getOrCreateName(op.getArray());
722+
for (auto index : op.getIndices()) {
723+
ss << "[" << getOrCreateName(index) << "]";
724+
}
725+
return out;
726+
}
727+
709728
/// Return the existing or a new name for a Value.
710729
StringRef CppEmitter::getOrCreateName(Value val) {
711730
if (auto literal = dyn_cast_if_present<emitc::LiteralOp>(val.getDefiningOp()))
712731
return literal.getValue();
713-
if (!valueMapper.count(val))
714-
valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
732+
if (!valueMapper.count(val)) {
733+
if (auto subscript =
734+
dyn_cast_if_present<emitc::SubscriptOp>(val.getDefiningOp())) {
735+
valueMapper.insert(val, getSubscriptName(subscript));
736+
} else {
737+
valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
738+
}
739+
}
715740
return *valueMapper.begin(val);
716741
}
717742

@@ -891,6 +916,8 @@ LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
891916

892917
LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
893918
bool trailingSemicolon) {
919+
if (isa<emitc::SubscriptOp>(result.getDefiningOp()))
920+
return success();
894921
if (hasValueInScope(result)) {
895922
return result.getDefiningOp()->emitError(
896923
"result variable for the operation already declared");
@@ -957,7 +984,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
957984
emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
958985
emitc::ConstantOp, emitc::DivOp, emitc::ForOp, emitc::IfOp,
959986
emitc::IncludeOp, emitc::MulOp, emitc::RemOp, emitc::SubOp,
960-
emitc::VariableOp>(
987+
emitc::SubscriptOp, emitc::VariableOp>(
961988
[&](auto op) { return printOperation(*this, op); })
962989
// Func ops.
963990
.Case<func::CallOp, func::ConstantOp, func::FuncOp, func::ReturnOp>(
@@ -973,7 +1000,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
9731000
if (failed(status))
9741001
return failure();
9751002

976-
if (isa<emitc::LiteralOp>(op))
1003+
if (isa<emitc::LiteralOp, emitc::SubscriptOp>(op))
9771004
return success();
9781005

9791006
os << (trailingSemicolon ? ";\n" : "\n");

mlir/test/Dialect/EmitC/invalid_ops.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,3 +224,11 @@ func.func @test_assign_type_mismatch(%arg1: f32) {
224224
emitc.assign %arg1 : f32 to %v : i32
225225
return
226226
}
227+
228+
// -----
229+
230+
func.func @test_subscript_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg2: index) {
231+
// expected-error @+1 {{'emitc.subscript' op requires number of indices (1) to match the rank of the array type (2)}}
232+
%0 = emitc.subscript %arg0[%arg2] : <4x8xf32>
233+
return
234+
}

mlir/test/Dialect/EmitC/ops.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,3 +149,11 @@ func.func @test_for_not_index_induction(%arg0 : i16, %arg1 : i16, %arg2 : i16) {
149149
}
150150
return
151151
}
152+
153+
func.func @test_subscript(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>,
154+
%arg2: index, %arg3: index) {
155+
%0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>
156+
%1 = emitc.subscript %arg1[%arg2, %arg3] : <3x5xf32>
157+
emitc.assign %0 : f32 to %1 : f32
158+
return
159+
}

mlir/test/Target/Cpp/subscript.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
2+
// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s
3+
4+
func.func @load_store(%arg0: !emitc.array<4x8xf32>, %arg1: !emitc.array<3x5xf32>, %arg2: index, %arg3: index) {
5+
%0 = emitc.subscript %arg0[%arg2, %arg3] : <4x8xf32>
6+
%1 = emitc.subscript %arg1[%arg2, %arg3] : <3x5xf32>
7+
emitc.assign %0 : f32 to %1 : f32
8+
return
9+
}
10+
// CHECK: void load_store(float [[ARR1:[^ ]*]][4][8], float [[ARR2:[^ ]*]][3][5],
11+
// CHECK-SAME: size_t [[I:[^ ]*]], size_t [[J:[^ ]*]])
12+
// CHECK-NEXT: [[ARR2]][[[I]]][[[J]]] = [[ARR1]][[[I]]][[[J]]];

0 commit comments

Comments
 (0)