Skip to content

Commit 126f037

Browse files
authored
Add a structured if operation (#67234)
Add an emitc.if op to the EmitC dialect. A new convert-scf-to-emitc pass replaces the existing direct translation of scf.if to C; The translator now handles emitc.if instead. The emitc.if op doesn't return any value and its then/else regions are terminated with a new scf.yield op. Values returned by scf.if are lowered using emitc.variable ops, assigned to in the then/else regions using a new emitc.assign op.
1 parent bd675f5 commit 126f037

File tree

15 files changed

+655
-42
lines changed

15 files changed

+655
-42
lines changed

mlir/docs/Dialects/emitc.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ translating the following operations:
3333
* `func.return`
3434
* 'scf' Dialect
3535
* `scf.for`
36-
* `scf.if`
3736
* `scf.yield`
3837
* 'arith' Dialect
3938
* `arith.constant`

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
4949
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
5050
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
51+
#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
5152
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
5253
#include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"
5354
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -931,6 +931,16 @@ def ConvertParallelLoopToGpu : Pass<"convert-parallel-loops-to-gpu"> {
931931
let dependentDialects = ["affine::AffineDialect", "gpu::GPUDialect"];
932932
}
933933

934+
//===----------------------------------------------------------------------===//
935+
// SCFToEmitC
936+
//===----------------------------------------------------------------------===//
937+
938+
def SCFToEmitC : Pass<"convert-scf-to-emitc"> {
939+
let summary = "Convert SCF dialect to EmitC dialect, maintaining structured"
940+
" control flow";
941+
let dependentDialects = ["emitc::EmitCDialect"];
942+
}
943+
934944
//===----------------------------------------------------------------------===//
935945
// ShapeToStandard
936946
//===----------------------------------------------------------------------===//
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
//===- SCFToEmitC.h - SCF to EmitC Pass entrypoint --------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H_
10+
#define MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H_
11+
12+
#include <memory>
13+
14+
namespace mlir {
15+
class Pass;
16+
class RewritePatternSet;
17+
18+
#define GEN_PASS_DECL_SCFTOEMITC
19+
#include "mlir/Conversion/Passes.h.inc"
20+
21+
/// Collect a set of patterns to convert SCF operations to the EmitC dialect.
22+
void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns);
23+
24+
/// Creates a pass to convert SCF operations to the EmitC dialect.
25+
std::unique_ptr<Pass> createConvertSCFToEmitCPass();
26+
27+
} // namespace mlir
28+
29+
#endif // MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H_

mlir/include/mlir/Dialect/EmitC/IR/EmitC.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,23 @@
1414
#define MLIR_DIALECT_EMITC_IR_EMITC_H
1515

1616
#include "mlir/Bytecode/BytecodeOpInterface.h"
17+
#include "mlir/IR/Builders.h"
1718
#include "mlir/IR/BuiltinOps.h"
1819
#include "mlir/IR/BuiltinTypes.h"
1920
#include "mlir/IR/Dialect.h"
2021
#include "mlir/Interfaces/CastInterfaces.h"
22+
#include "mlir/Interfaces/ControlFlowInterfaces.h"
2123
#include "mlir/Interfaces/SideEffectInterfaces.h"
2224

2325
#include "mlir/Dialect/EmitC/IR/EmitCDialect.h.inc"
2426
#include "mlir/Dialect/EmitC/IR/EmitCEnums.h.inc"
2527

28+
namespace mlir {
29+
namespace emitc {
30+
void buildTerminatedBody(OpBuilder &builder, Location loc);
31+
} // namespace emitc
32+
} // namespace mlir
33+
2634
#define GET_ATTRDEF_CLASSES
2735
#include "mlir/Dialect/EmitC/IR/EmitCAttributes.h.inc"
2836

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ include "mlir/Dialect/EmitC/IR/EmitCAttributes.td"
1717
include "mlir/Dialect/EmitC/IR/EmitCTypes.td"
1818

1919
include "mlir/Interfaces/CastInterfaces.td"
20+
include "mlir/Interfaces/ControlFlowInterfaces.td"
2021
include "mlir/Interfaces/SideEffectInterfaces.td"
2122

2223
//===----------------------------------------------------------------------===//
@@ -402,4 +403,102 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> {
402403
let hasVerifier = 1;
403404
}
404405

406+
def EmitC_AssignOp : EmitC_Op<"assign", []> {
407+
let summary = "Assign operation";
408+
let description = [{
409+
The `assign` operation stores an SSA value to the location designated by an
410+
EmitC variable. This operation doesn't return any value. The assigned value
411+
must be of the same type as the variable being assigned. The operation is
412+
emitted as a C/C++ '=' operator.
413+
414+
Example:
415+
416+
```mlir
417+
// Integer variable
418+
%0 = "emitc.variable"(){value = 42 : i32} : () -> i32
419+
%1 = emitc.call "foo"() : () -> (i32)
420+
421+
// Assign emitted as `... = ...;`
422+
"emitc.assign"(%0, %1) : (i32, i32) -> ()
423+
```
424+
}];
425+
426+
let arguments = (ins AnyType:$var, AnyType:$value);
427+
let results = (outs);
428+
429+
let hasVerifier = 1;
430+
let assemblyFormat = "$value `:` type($value) `to` $var `:` type($var) attr-dict";
431+
}
432+
433+
def EmitC_YieldOp : EmitC_Op<"yield", [Pure, Terminator, ParentOneOf<["IfOp"]>]> {
434+
let summary = "block termination operation";
435+
let description = [{
436+
"yield" terminates blocks within EmitC control-flow operations. Since
437+
control-flow constructs in C do not return values, this operation doesn't
438+
take any arguments.
439+
}];
440+
441+
let arguments = (ins);
442+
let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
443+
444+
let assemblyFormat = [{ attr-dict }];
445+
}
446+
447+
def EmitC_IfOp : EmitC_Op<"if",
448+
[DeclareOpInterfaceMethods<RegionBranchOpInterface, [
449+
"getNumRegionInvocations", "getRegionInvocationBounds",
450+
"getEntrySuccessorRegions"]>, SingleBlock,
451+
SingleBlockImplicitTerminator<"emitc::YieldOp">,
452+
RecursiveMemoryEffects, NoRegionArguments]> {
453+
let summary = "if-then-else operation";
454+
let description = [{
455+
The `if` operation represents an if-then-else construct for
456+
conditionally executing two regions of code. The operand to an if operation
457+
is a boolean value. For example:
458+
459+
```mlir
460+
emitc.if %b {
461+
...
462+
} else {
463+
...
464+
}
465+
```
466+
467+
The "then" region has exactly 1 block. The "else" region may have 0 or 1
468+
blocks. The blocks are always terminated with `emitc.yield`, which can be
469+
left out to be inserted implicitly. This operation doesn't produce any
470+
results.
471+
}];
472+
let arguments = (ins I1:$condition);
473+
let results = (outs);
474+
let regions = (region SizedRegion<1>:$thenRegion,
475+
MaxSizedRegion<1>:$elseRegion);
476+
477+
let skipDefaultBuilders = 1;
478+
let builders = [
479+
OpBuilder<(ins "Value":$cond)>,
480+
OpBuilder<(ins "Value":$cond, "bool":$addThenBlock, "bool":$addElseBlock)>,
481+
OpBuilder<(ins "Value":$cond, "bool":$withElseRegion)>,
482+
OpBuilder<(ins "Value":$cond,
483+
CArg<"function_ref<void(OpBuilder &, Location)>",
484+
"buildTerminatedBody">:$thenBuilder,
485+
CArg<"function_ref<void(OpBuilder &, Location)>",
486+
"nullptr">:$elseBuilder)>,
487+
];
488+
489+
let extraClassDeclaration = [{
490+
OpBuilder getThenBodyBuilder(OpBuilder::Listener *listener = nullptr) {
491+
Block* body = getBody(0);
492+
return OpBuilder::atBlockEnd(body, listener);
493+
}
494+
OpBuilder getElseBodyBuilder(OpBuilder::Listener *listener = nullptr) {
495+
Block* body = getBody(1);
496+
return OpBuilder::atBlockEnd(body, listener);
497+
}
498+
Block* thenBlock();
499+
Block* elseBlock();
500+
}];
501+
let hasCustomAssemblyFormat = 1;
502+
}
503+
405504
#endif // MLIR_DIALECT_EMITC_IR_EMITC

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ add_subdirectory(OpenMPToLLVM)
3838
add_subdirectory(PDLToPDLInterp)
3939
add_subdirectory(ReconcileUnrealizedCasts)
4040
add_subdirectory(SCFToControlFlow)
41+
add_subdirectory(SCFToEmitC)
4142
add_subdirectory(SCFToGPU)
4243
add_subdirectory(SCFToOpenMP)
4344
add_subdirectory(SCFToSPIRV)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
add_mlir_conversion_library(MLIRSCFToEmitC
2+
SCFToEmitC.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SCFToEmitC
6+
7+
DEPENDS
8+
MLIRConversionPassIncGen
9+
10+
LINK_COMPONENTS
11+
Core
12+
13+
LINK_LIBS PUBLIC
14+
MLIRArithDialect
15+
MLIREmitCDialect
16+
MLIRSCFDialect
17+
MLIRTransforms
18+
)
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
//===- SCFToEmitC.cpp - SCF to EmitC conversion ---------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pass to convert scf.if ops into emitc ops.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
14+
15+
#include "mlir/Dialect/Arith/IR/Arith.h"
16+
#include "mlir/Dialect/EmitC/IR/EmitC.h"
17+
#include "mlir/Dialect/SCF/IR/SCF.h"
18+
#include "mlir/IR/Builders.h"
19+
#include "mlir/IR/BuiltinOps.h"
20+
#include "mlir/IR/IRMapping.h"
21+
#include "mlir/IR/MLIRContext.h"
22+
#include "mlir/IR/PatternMatch.h"
23+
#include "mlir/Transforms/DialectConversion.h"
24+
#include "mlir/Transforms/Passes.h"
25+
26+
namespace mlir {
27+
#define GEN_PASS_DEF_SCFTOEMITC
28+
#include "mlir/Conversion/Passes.h.inc"
29+
} // namespace mlir
30+
31+
using namespace mlir;
32+
using namespace mlir::scf;
33+
34+
namespace {
35+
36+
struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
37+
void runOnOperation() override;
38+
};
39+
40+
// Lower scf::if to emitc::if, implementing return values as emitc::variable's
41+
// updated within the then and else regions.
42+
struct IfLowering : public OpRewritePattern<IfOp> {
43+
using OpRewritePattern<IfOp>::OpRewritePattern;
44+
45+
LogicalResult matchAndRewrite(IfOp ifOp,
46+
PatternRewriter &rewriter) const override;
47+
};
48+
49+
} // namespace
50+
51+
LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
52+
PatternRewriter &rewriter) const {
53+
Location loc = ifOp.getLoc();
54+
55+
SmallVector<Value> resultVariables;
56+
57+
// Create an emitc::variable op for each result. These variables will be
58+
// assigned to by emitc::assign ops within the then & else regions.
59+
if (ifOp.getNumResults()) {
60+
MLIRContext *context = ifOp.getContext();
61+
rewriter.setInsertionPoint(ifOp);
62+
for (OpResult result : ifOp.getResults()) {
63+
Type resultType = result.getType();
64+
auto noInit = emitc::OpaqueAttr::get(context, "");
65+
auto var = rewriter.create<emitc::VariableOp>(loc, resultType, noInit);
66+
resultVariables.push_back(var);
67+
}
68+
}
69+
70+
// Utility function to lower the contents of an scf::if region to an emitc::if
71+
// region. The contents of the scf::if regions is moved into the respective
72+
// emitc::if regions, but the scf::yield is replaced not only with an
73+
// emitc::yield, but also with a sequence of emitc::assign ops that set the
74+
// yielded values into the result variables.
75+
auto lowerRegion = [&resultVariables, &rewriter](Region &region,
76+
Region &loweredRegion) {
77+
rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
78+
Operation *terminator = loweredRegion.back().getTerminator();
79+
Location terminatorLoc = terminator->getLoc();
80+
ValueRange terminatorOperands = terminator->getOperands();
81+
rewriter.setInsertionPointToEnd(&loweredRegion.back());
82+
for (auto value2Var : llvm::zip(terminatorOperands, resultVariables)) {
83+
Value resultValue = std::get<0>(value2Var);
84+
Value resultVar = std::get<1>(value2Var);
85+
rewriter.create<emitc::AssignOp>(terminatorLoc, resultVar, resultValue);
86+
}
87+
rewriter.create<emitc::YieldOp>(terminatorLoc);
88+
rewriter.eraseOp(terminator);
89+
};
90+
91+
Region &thenRegion = ifOp.getThenRegion();
92+
Region &elseRegion = ifOp.getElseRegion();
93+
94+
bool hasElseBlock = !elseRegion.empty();
95+
96+
auto loweredIf =
97+
rewriter.create<emitc::IfOp>(loc, ifOp.getCondition(), false, false);
98+
99+
Region &loweredThenRegion = loweredIf.getThenRegion();
100+
lowerRegion(thenRegion, loweredThenRegion);
101+
102+
if (hasElseBlock) {
103+
Region &loweredElseRegion = loweredIf.getElseRegion();
104+
lowerRegion(elseRegion, loweredElseRegion);
105+
}
106+
107+
rewriter.replaceOp(ifOp, resultVariables);
108+
return success();
109+
}
110+
111+
void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns) {
112+
patterns.add<IfLowering>(patterns.getContext());
113+
}
114+
115+
void SCFToEmitCPass::runOnOperation() {
116+
RewritePatternSet patterns(&getContext());
117+
populateSCFToEmitCConversionPatterns(patterns);
118+
119+
// Configure conversion to lower out SCF operations.
120+
ConversionTarget target(getContext());
121+
target.addIllegalOp<scf::IfOp>();
122+
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
123+
if (failed(
124+
applyPartialConversion(getOperation(), target, std::move(patterns))))
125+
signalPassFailure();
126+
}
127+
128+
std::unique_ptr<Pass> mlir::createConvertSCFToEmitCPass() {
129+
return std::make_unique<SCFToEmitCPass>();
130+
}

0 commit comments

Comments
 (0)