Skip to content

Commit 7971655

Browse files
committed
[mlir] Add a generic while/do-while loop to the SCF dialect
The new construct represents a generic loop with two regions: one executed before the loop condition is verifier and another after that. This construct can be used to express both a "while" loop and a "do-while" loop, depending on where the main payload is located. It is intended as an intermediate abstraction for lowering, which will be added later. This form is relatively easy to target from higher-level abstractions and supports transformations such as loop rotation and LICM. Differential Revision: https://reviews.llvm.org/D90255
1 parent 3bec07f commit 7971655

File tree

7 files changed

+484
-38
lines changed

7 files changed

+484
-38
lines changed

mlir/include/mlir/Dialect/SCF/SCFOps.td

Lines changed: 148 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,25 @@ class SCF_Op<string mnemonic, list<OpTrait> traits = []> :
3636
let parser = [{ return ::parse$cppClass(parser, result); }];
3737
}
3838

39+
def ConditionOp : SCF_Op<"condition",
40+
[HasParent<"WhileOp">, NoSideEffect, Terminator]> {
41+
let summary = "loop continuation condition";
42+
let description = [{
43+
This operation accepts the continuation (i.e., inverse of exit) condition
44+
of the `scf.while` construct. If its first argument is true, the "after"
45+
region of `scf.while` is executed, with the remaining arguments forwarded
46+
to the entry block of the region. Otherwise, the loop terminates.
47+
}];
48+
49+
let arguments = (ins I1:$condition, Variadic<AnyType>:$args);
50+
51+
let assemblyFormat =
52+
[{ `(` $condition `)` attr-dict ($args^ `:` type($args))? }];
53+
54+
// Override the default verifier, everything is checked by traits.
55+
let verifier = ?;
56+
}
57+
3958
def ForOp : SCF_Op<"for",
4059
[DeclareOpInterfaceMethods<LoopLikeOpInterface>,
4160
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
@@ -413,8 +432,135 @@ def ReduceReturnOp :
413432
let assemblyFormat = "$result attr-dict `:` type($result)";
414433
}
415434

435+
def WhileOp : SCF_Op<"while",
436+
[DeclareOpInterfaceMethods<RegionBranchOpInterface>,
437+
RecursiveSideEffects]> {
438+
let summary = "a generic 'while' loop";
439+
let description = [{
440+
This operation represents a generic "while"/"do-while" loop that keeps
441+
iterating as long as a condition is satisfied. There is no restriction on
442+
the complexity of the condition. It consists of two regions (with single
443+
block each): "before" region and "after" region. The names of regions
444+
indicates whether they execute before or after the condition check.
445+
Therefore, if the main loop payload is located in the "before" region, the
446+
operation is a "do-while" loop. Otherwise, it is a "while" loop.
447+
448+
The "before" region terminates with a special operation, `scf.condition`,
449+
that accepts as its first operand an `i1` value indicating whether to
450+
proceed to the "after" region (value is `true`) or not. The two regions
451+
communicate by means of region arguments. Initially, the "before" region
452+
accepts as arguments the operands of the `scf.while` operation and uses them
453+
to evaluate the condition. It forwards the trailing, non-condition operands
454+
of the `scf.condition` terminator either to the "after" region if the
455+
control flow is transferred there or to results of the `scf.while` operation
456+
otherwise. The "after" region takes as arguments the values produced by the
457+
"before" region and uses `scf.yield` to supply new arguments for the "after"
458+
region, into which it transfers the control flow unconditionally.
459+
460+
A simple "while" loop can be represented as follows.
461+
462+
```mlir
463+
%res = scf.while (%arg1 = %init1) : (f32) -> f32 {
464+
/* "Before" region.
465+
* In a "while" loop, this region computes the condition. */
466+
%condition = call @evaluate_condition(%arg1) : (f32) -> i1
467+
468+
/* Forward the argument (as result or "after" region argument). */
469+
scf.condition(%condition) %arg1 : f32
470+
471+
} do {
472+
^bb0(%arg2: f32):
473+
/* "After region.
474+
* In a "while" loop, this region is the loop body. */
475+
%next = call @payload(%arg2) : (f32) -> f32
476+
477+
/* Forward the new value to the "before" region.
478+
* The operand types must match the types of the `scf.while` operands. */
479+
scf.yield %next : f32
480+
}
481+
```
482+
483+
A simple "do-while" loop can be represented by reducing the "after" block
484+
to a simple forwarder.
485+
486+
```mlir
487+
%res = scf.while (%arg1 = %init1) : (f32) -> f32 {
488+
/* "Before" region.
489+
* In a "do-while" loop, this region contains the loop body. */
490+
%next = call @payload(%arg1) : (f32) -> f32
491+
492+
/* And also evalutes the condition. */
493+
%condition = call @evaluate_condition(%arg1) : (f32) -> i1
494+
495+
/* Loop through the "after" region. */
496+
scf.condition(%condition) %next : f32
497+
498+
} do {
499+
^bb0(%arg2: f32):
500+
/* "After" region.
501+
* Forwards the values back to "before" region unmodified. */
502+
scf.yield %arg2 : f32
503+
}
504+
```
505+
506+
Note that the types of region arguments need not to match with each other.
507+
The op expects the operand types to match with argument types of the
508+
"before" region"; the result types to match with the trailing operand types
509+
of the terminator of the "before" region, and with the argument types of the
510+
"after" region. The following scheme can be used to share the results of
511+
some operations executed in the "before" region with the "after" region,
512+
avoiding the need to recompute them.
513+
514+
```mlir
515+
%res = scf.while (%arg1 = %init1) : (f32) -> i64 {
516+
/* One can perform some computations, e.g., necessary to evaluate the
517+
* condition, in the "before" region and forward their results to the
518+
* "after" region. */
519+
%shared = call @shared_compute(%arg1) : (f32) -> i64
520+
521+
/* Evalute the condition. */
522+
%condition = call @evaluate_condition(%arg1, %shared) : (f32, i64) -> i1
523+
524+
/* Forward the result of the shared computation to the "after" region.
525+
* The types must match the arguments of the "after" region as well as
526+
* those of the `scf.while` results. */
527+
scf.condition(%condition) %shared : i64
528+
529+
} do {
530+
^bb0(%arg2: i64) {
531+
/* Use the partial result to compute the rest of the payload in the
532+
* "after" region. */
533+
%res = call @payload(%arg2) : (i64) -> f32
534+
535+
/* Forward the new value to the "before" region.
536+
* The operand types must match the types of the `scf.while` operands. */
537+
scf.yield %res : f32
538+
}
539+
```
540+
541+
The custom syntax for this operation is as follows.
542+
543+
```
544+
op ::= `scf.while` assignments `:` function-type region `do` region
545+
`attributes` attribute-dict
546+
initializer ::= /* empty */ | `(` assignment-list `)`
547+
assignment-list ::= assignment | assignment `,` assignment-list
548+
assignment ::= ssa-value `=` ssa-value
549+
```
550+
}];
551+
552+
let arguments = (ins Variadic<AnyType>:$inits);
553+
let results = (outs Variadic<AnyType>:$results);
554+
let regions = (region SizedRegion<1>:$before, SizedRegion<1>:$after);
555+
556+
let extraClassDeclaration = [{
557+
OperandRange getSuccessorEntryOperands(unsigned index);
558+
}];
559+
}
560+
416561
def YieldOp : SCF_Op<"yield", [NoSideEffect, ReturnLike, Terminator,
417-
ParentOneOf<["IfOp, ForOp", "ParallelOp"]>]> {
562+
ParentOneOf<["IfOp, ForOp", "ParallelOp",
563+
"WhileOp"]>]> {
418564
let summary = "loop yield and termination operation";
419565
let description = [{
420566
"scf.yield" yields an SSA value from the SCF dialect op region and
@@ -434,4 +580,5 @@ def YieldOp : SCF_Op<"yield", [NoSideEffect, ReturnLike, Terminator,
434580
// needed.
435581
let verifier = ?;
436582
}
583+
437584
#endif // MLIR_DIALECT_SCF_SCFOPS

mlir/include/mlir/IR/OpImplementation.h

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -755,11 +755,18 @@ class OpAsmParser {
755755
parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0;
756756

757757
/// Parse a list of assignments of the form
758-
/// (%x1 = %y1 : type1, %x2 = %y2 : type2, ...).
759-
/// The list must contain at least one entry
760-
virtual ParseResult
761-
parseAssignmentList(SmallVectorImpl<OperandType> &lhs,
762-
SmallVectorImpl<OperandType> &rhs) = 0;
758+
/// (%x1 = %y1, %x2 = %y2, ...)
759+
ParseResult parseAssignmentList(SmallVectorImpl<OperandType> &lhs,
760+
SmallVectorImpl<OperandType> &rhs) {
761+
OptionalParseResult result = parseOptionalAssignmentList(lhs, rhs);
762+
if (!result.hasValue())
763+
return emitError(getCurrentLocation(), "expected '('");
764+
return result.getValue();
765+
}
766+
767+
virtual OptionalParseResult
768+
parseOptionalAssignmentList(SmallVectorImpl<OperandType> &lhs,
769+
SmallVectorImpl<OperandType> &rhs) = 0;
763770

764771
/// Parse a keyword followed by a type.
765772
ParseResult parseKeywordType(const char *keyword, Type &result) {

0 commit comments

Comments
 (0)