Skip to content

Commit 2ae5ded

Browse files
authored
[mlir][tosa] Update ControlFlow variable names to match with TOSA v1.0 spec (#129790)
1 parent 91aac7c commit 2ae5ded

File tree

4 files changed

+33
-30
lines changed

4 files changed

+33
-30
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2471,12 +2471,12 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
24712471
}];
24722472

24732473
let arguments = (ins
2474-
Tosa_I1Tensor:$cond,
2475-
Variadic<Tosa_Tensor>:$inputs
2474+
Tosa_I1Tensor:$condition,
2475+
Variadic<Tosa_Tensor>:$input_list
24762476
);
24772477

24782478
let results = (outs
2479-
Variadic<Tosa_Tensor>:$output
2479+
Variadic<Tosa_Tensor>:$output_list
24802480
);
24812481

24822482
list<Availability> availability = [
@@ -2485,8 +2485,8 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
24852485
];
24862486

24872487
let regions = (region
2488-
SizedRegion<1>:$then_branch,
2489-
SizedRegion<1>:$else_branch
2488+
SizedRegion<1>:$then_graph,
2489+
SizedRegion<1>:$else_graph
24902490
);
24912491

24922492
let hasCustomAssemblyFormat = 1;
@@ -2513,11 +2513,11 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
25132513
}];
25142514

25152515
let arguments = (ins
2516-
Variadic<Tosa_Tensor>:$inputs
2516+
Variadic<Tosa_Tensor>:$input_list
25172517
);
25182518

25192519
let results = (outs
2520-
Variadic<Tosa_Tensor>:$output
2520+
Variadic<Tosa_Tensor>:$output_list
25212521
);
25222522

25232523
list<Availability> availability = [
@@ -2526,8 +2526,8 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
25262526
];
25272527

25282528
let regions = (region
2529-
SizedRegion<1>:$cond,
2530-
SizedRegion<1>:$body
2529+
SizedRegion<1>:$cond_graph,
2530+
SizedRegion<1>:$body_graph
25312531
);
25322532

25332533
let hasCustomAssemblyFormat = 1;

mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,13 @@ class IfOpConverter : public OpRewritePattern<tosa::IfOp> {
6868
LogicalResult matchAndRewrite(tosa::IfOp op,
6969
PatternRewriter &rewriter) const final {
7070
auto condition =
71-
rewriter.create<tensor::ExtractOp>(op.getLoc(), op.getCond());
71+
rewriter.create<tensor::ExtractOp>(op.getLoc(), op.getCondition());
7272
auto newIf = rewriter.create<scf::IfOp>(op.getLoc(), op.getResultTypes(),
7373
condition, true);
7474

75-
inlineIfCase(op.getThenBranch(), newIf.getThenRegion(), op.getInputs(),
75+
inlineIfCase(op.getThenGraph(), newIf.getThenRegion(), op.getInputList(),
7676
rewriter);
77-
inlineIfCase(op.getElseBranch(), newIf.getElseRegion(), op.getInputs(),
77+
inlineIfCase(op.getElseGraph(), newIf.getElseRegion(), op.getInputList(),
7878
rewriter);
7979

8080
rewriter.replaceOp(op, newIf.getResults());
@@ -158,12 +158,12 @@ class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
158158
LogicalResult matchAndRewrite(tosa::WhileOp op,
159159
PatternRewriter &rewriter) const final {
160160
auto newWhile = rewriter.create<scf::WhileOp>(
161-
op.getLoc(), op.getResultTypes(), op.getInputs());
161+
op.getLoc(), op.getResultTypes(), op.getInputList());
162162
rewriter.createBlock(&newWhile.getBefore());
163163
rewriter.createBlock(&newWhile.getAfter());
164164

165-
inlineWhileCase(op.getCond(), newWhile.getBefore(), rewriter, true);
166-
inlineWhileCase(op.getBody(), newWhile.getAfter(), rewriter, false);
165+
inlineWhileCase(op.getCondGraph(), newWhile.getBefore(), rewriter, true);
166+
inlineWhileCase(op.getBodyGraph(), newWhile.getAfter(), rewriter, false);
167167

168168
rewriter.replaceOp(op, newWhile.getResults());
169169

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,9 @@ struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
127127
//===----------------------------------------------------------------------===//
128128

129129
/// Returns the while loop body.
130-
SmallVector<Region *> tosa::WhileOp::getLoopRegions() { return {&getBody()}; }
130+
SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
131+
return {&getBodyGraph()};
132+
}
131133

132134
//===----------------------------------------------------------------------===//
133135
// Tosa dialect initialization.
@@ -2536,7 +2538,7 @@ LogicalResult WhileOp::inferReturnTypeComponents(
25362538
WhileOp::Adaptor adaptor,
25372539
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
25382540
llvm::SmallVector<tosa::YieldOp> yieldOps;
2539-
for (auto &block : adaptor.getBody())
2541+
for (auto &block : adaptor.getBodyGraph())
25402542
if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
25412543
yieldOps.push_back(returnOp);
25422544

@@ -2616,19 +2618,19 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
26162618
void IfOp::print(OpAsmPrinter &p) {
26172619
bool printBlockTerminators = false;
26182620

2619-
p << " " << getCond();
2621+
p << " " << getCondition();
26202622
if (!getResults().empty()) {
26212623
p << " -> (" << getResultTypes() << ")";
26222624
// Print yield explicitly if the op defines values.
26232625
printBlockTerminators = true;
26242626
}
26252627
p << ' ';
2626-
p.printRegion(getThenBranch(),
2628+
p.printRegion(getThenGraph(),
26272629
/*printEntryBlockArgs=*/false,
26282630
/*printBlockTerminators=*/printBlockTerminators);
26292631

26302632
// Print the 'else' regions if it exists and has a block.
2631-
auto &elseRegion = getElseBranch();
2633+
auto &elseRegion = getElseGraph();
26322634
if (!elseRegion.empty()) {
26332635
p << " else ";
26342636
p.printRegion(elseRegion,
@@ -2726,14 +2728,15 @@ static void printInitializationList(OpAsmPrinter &parser,
27262728
}
27272729

27282730
void WhileOp::print(OpAsmPrinter &parser) {
2729-
printInitializationList(parser, getCond().front().getArguments(), getInputs(),
2730-
" ");
2731+
printInitializationList(parser, getCondGraph().front().getArguments(),
2732+
getInputList(), " ");
27312733
parser << " : ";
2732-
parser.printFunctionalType(getInputs().getTypes(), getResults().getTypes());
2734+
parser.printFunctionalType(getInputList().getTypes(),
2735+
getResults().getTypes());
27332736
parser << ' ';
2734-
parser.printRegion(getCond(), /*printEntryBlockArgs=*/false);
2737+
parser.printRegion(getCondGraph(), /*printEntryBlockArgs=*/false);
27352738
parser << " do ";
2736-
parser.printRegion(getBody());
2739+
parser.printRegion(getBodyGraph());
27372740
parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
27382741
}
27392742

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -371,14 +371,14 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
371371
}
372372
}
373373
if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
374-
if (!levelCheckListSize(op, condIf.getInputs().size(), "inputs") ||
375-
!levelCheckListSize(op, condIf.getOutput().size(), "outputs")) {
374+
if (!levelCheckListSize(op, condIf.getInputList().size(), "inputs") ||
375+
!levelCheckListSize(op, condIf.getOutputList().size(), "outputs")) {
376376
return false;
377377
}
378378
}
379379
if (auto w = dyn_cast<tosa::WhileOp>(op)) {
380-
if (!levelCheckListSize(op, w.getInputs().size(), "inputs") ||
381-
!levelCheckListSize(op, w.getOutput().size(), "outputs")) {
380+
if (!levelCheckListSize(op, w.getInputList().size(), "inputs") ||
381+
!levelCheckListSize(op, w.getOutputList().size(), "outputs")) {
382382
return false;
383383
}
384384
}
@@ -450,7 +450,7 @@ bool TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
450450
auto op = tosaOp.getOperation();
451451

452452
// Only the condition input has rank limitation.
453-
if (!levelCheckRank(op, tosaOp.getCond(), "operand", tosaLevel.MAX_RANK))
453+
if (!levelCheckRank(op, tosaOp.getCondition(), "operand", tosaLevel.MAX_RANK))
454454
return false;
455455

456456
return true;

0 commit comments

Comments
 (0)