Skip to content

Add a structured if operation #67234

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion mlir/docs/Dialects/emitc.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ translating the following operations:
* `func.return`
* 'scf' Dialect
* `scf.for`
* `scf.if`
* `scf.yield`
* 'arith' Dialect
* `arith.constant`
1 change: 1 addition & 0 deletions mlir/include/mlir/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
#include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h"
Expand Down
10 changes: 10 additions & 0 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -931,6 +931,16 @@ def ConvertParallelLoopToGpu : Pass<"convert-parallel-loops-to-gpu"> {
let dependentDialects = ["affine::AffineDialect", "gpu::GPUDialect"];
}

//===----------------------------------------------------------------------===//
// SCFToEmitC
//===----------------------------------------------------------------------===//

def SCFToEmitC : Pass<"convert-scf-to-emitc"> {
let summary = "Convert SCF dialect to EmitC dialect, maintaining structured"
" control flow";
let dependentDialects = ["emitc::EmitCDialect"];
}

//===----------------------------------------------------------------------===//
// ShapeToStandard
//===----------------------------------------------------------------------===//
Expand Down
29 changes: 29 additions & 0 deletions mlir/include/mlir/Conversion/SCFToEmitC/SCFToEmitC.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
//===- SCFToEmitC.h - SCF to EmitC Pass entrypoint --------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H_
#define MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H_

#include <memory>

namespace mlir {
class Pass;
class RewritePatternSet;

#define GEN_PASS_DECL_SCFTOEMITC
#include "mlir/Conversion/Passes.h.inc"

/// Collect a set of patterns to convert SCF operations to the EmitC dialect.
void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns);

/// Creates a pass to convert SCF operations to the EmitC dialect.
std::unique_ptr<Pass> createConvertSCFToEmitCPass();

} // namespace mlir

#endif // MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H_
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,23 @@
#define MLIR_DIALECT_EMITC_IR_EMITC_H

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "mlir/Dialect/EmitC/IR/EmitCDialect.h.inc"
#include "mlir/Dialect/EmitC/IR/EmitCEnums.h.inc"

namespace mlir {
namespace emitc {
void buildTerminatedBody(OpBuilder &builder, Location loc);
} // namespace emitc
} // namespace mlir

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/EmitC/IR/EmitCAttributes.h.inc"

Expand Down
99 changes: 99 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ include "mlir/Dialect/EmitC/IR/EmitCAttributes.td"
include "mlir/Dialect/EmitC/IR/EmitCTypes.td"

include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -402,4 +403,102 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> {
let hasVerifier = 1;
}

def EmitC_AssignOp : EmitC_Op<"assign", []> {
let summary = "Assign operation";
let description = [{
The `assign` operation stores an SSA value to the location designated by an
EmitC variable. This operation doesn't return any value. The assigned value
must be of the same type as the variable being assigned. The operation is
emitted as a C/C++ '=' operator.

Example:

```mlir
// Integer variable
%0 = "emitc.variable"(){value = 42 : i32} : () -> i32
%1 = emitc.call "foo"() : () -> (i32)

// Assign emitted as `... = ...;`
"emitc.assign"(%0, %1) : (i32, i32) -> ()
```
}];

let arguments = (ins AnyType:$var, AnyType:$value);
let results = (outs);

let hasVerifier = 1;
let assemblyFormat = "$value `:` type($value) `to` $var `:` type($var) attr-dict";
}

def EmitC_YieldOp : EmitC_Op<"yield", [Pure, Terminator, ParentOneOf<["IfOp"]>]> {
let summary = "block termination operation";
let description = [{
"yield" terminates blocks within EmitC control-flow operations. Since
control-flow constructs in C do not return values, this operation doesn't
take any arguments.
}];

let arguments = (ins);
let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];

let assemblyFormat = [{ attr-dict }];
}

def EmitC_IfOp : EmitC_Op<"if",
[DeclareOpInterfaceMethods<RegionBranchOpInterface, [
"getNumRegionInvocations", "getRegionInvocationBounds",
"getEntrySuccessorRegions"]>, SingleBlock,
SingleBlockImplicitTerminator<"emitc::YieldOp">,
RecursiveMemoryEffects, NoRegionArguments]> {
let summary = "if-then-else operation";
let description = [{
The `if` operation represents an if-then-else construct for
conditionally executing two regions of code. The operand to an if operation
is a boolean value. For example:

```mlir
emitc.if %b {
...
} else {
...
}
```

The "then" region has exactly 1 block. The "else" region may have 0 or 1
blocks. The blocks are always terminated with `emitc.yield`, which can be
left out to be inserted implicitly. This operation doesn't produce any
results.
}];
let arguments = (ins I1:$condition);
let results = (outs);
let regions = (region SizedRegion<1>:$thenRegion,
MaxSizedRegion<1>:$elseRegion);

let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins "Value":$cond)>,
OpBuilder<(ins "Value":$cond, "bool":$addThenBlock, "bool":$addElseBlock)>,
OpBuilder<(ins "Value":$cond, "bool":$withElseRegion)>,
OpBuilder<(ins "Value":$cond,
CArg<"function_ref<void(OpBuilder &, Location)>",
"buildTerminatedBody">:$thenBuilder,
CArg<"function_ref<void(OpBuilder &, Location)>",
"nullptr">:$elseBuilder)>,
];

let extraClassDeclaration = [{
OpBuilder getThenBodyBuilder(OpBuilder::Listener *listener = nullptr) {
Block* body = getBody(0);
return OpBuilder::atBlockEnd(body, listener);
}
OpBuilder getElseBodyBuilder(OpBuilder::Listener *listener = nullptr) {
Block* body = getBody(1);
return OpBuilder::atBlockEnd(body, listener);
}
Block* thenBlock();
Block* elseBlock();
}];
let hasCustomAssemblyFormat = 1;
}

#endif // MLIR_DIALECT_EMITC_IR_EMITC
1 change: 1 addition & 0 deletions mlir/lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ add_subdirectory(OpenMPToLLVM)
add_subdirectory(PDLToPDLInterp)
add_subdirectory(ReconcileUnrealizedCasts)
add_subdirectory(SCFToControlFlow)
add_subdirectory(SCFToEmitC)
add_subdirectory(SCFToGPU)
add_subdirectory(SCFToOpenMP)
add_subdirectory(SCFToSPIRV)
Expand Down
18 changes: 18 additions & 0 deletions mlir/lib/Conversion/SCFToEmitC/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
add_mlir_conversion_library(MLIRSCFToEmitC
SCFToEmitC.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SCFToEmitC

DEPENDS
MLIRConversionPassIncGen

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
MLIRArithDialect
MLIREmitCDialect
MLIRSCFDialect
MLIRTransforms
)
130 changes: 130 additions & 0 deletions mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
//===- SCFToEmitC.cpp - SCF to EmitC conversion ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert scf.if ops into emitc ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"

namespace mlir {
#define GEN_PASS_DEF_SCFTOEMITC
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::scf;

namespace {

struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
void runOnOperation() override;
};

// Lower scf::if to emitc::if, implementing return values as emitc::variable's
// updated within the then and else regions.
struct IfLowering : public OpRewritePattern<IfOp> {
using OpRewritePattern<IfOp>::OpRewritePattern;

LogicalResult matchAndRewrite(IfOp ifOp,
PatternRewriter &rewriter) const override;
};

} // namespace

LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
PatternRewriter &rewriter) const {
Location loc = ifOp.getLoc();

SmallVector<Value> resultVariables;

// Create an emitc::variable op for each result. These variables will be
// assigned to by emitc::assign ops within the then & else regions.
if (ifOp.getNumResults()) {
MLIRContext *context = ifOp.getContext();
rewriter.setInsertionPoint(ifOp);
for (OpResult result : ifOp.getResults()) {
Type resultType = result.getType();
auto noInit = emitc::OpaqueAttr::get(context, "");
auto var = rewriter.create<emitc::VariableOp>(loc, resultType, noInit);
resultVariables.push_back(var);
}
}

// Utility function to lower the contents of an scf::if region to an emitc::if
// region. The contents of the scf::if regions is moved into the respective
// emitc::if regions, but the scf::yield is replaced not only with an
// emitc::yield, but also with a sequence of emitc::assign ops that set the
// yielded values into the result variables.
auto lowerRegion = [&resultVariables, &rewriter](Region &region,
Region &loweredRegion) {
rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
Operation *terminator = loweredRegion.back().getTerminator();
Location terminatorLoc = terminator->getLoc();
ValueRange terminatorOperands = terminator->getOperands();
rewriter.setInsertionPointToEnd(&loweredRegion.back());
for (auto value2Var : llvm::zip(terminatorOperands, resultVariables)) {
Value resultValue = std::get<0>(value2Var);
Value resultVar = std::get<1>(value2Var);
rewriter.create<emitc::AssignOp>(terminatorLoc, resultVar, resultValue);
}
rewriter.create<emitc::YieldOp>(terminatorLoc);
rewriter.eraseOp(terminator);
};

Region &thenRegion = ifOp.getThenRegion();
Region &elseRegion = ifOp.getElseRegion();

bool hasElseBlock = !elseRegion.empty();

auto loweredIf =
rewriter.create<emitc::IfOp>(loc, ifOp.getCondition(), false, false);

Region &loweredThenRegion = loweredIf.getThenRegion();
lowerRegion(thenRegion, loweredThenRegion);

if (hasElseBlock) {
Region &loweredElseRegion = loweredIf.getElseRegion();
lowerRegion(elseRegion, loweredElseRegion);
}

rewriter.replaceOp(ifOp, resultVariables);
return success();
}

void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns) {
patterns.add<IfLowering>(patterns.getContext());
}

void SCFToEmitCPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateSCFToEmitCConversionPatterns(patterns);

// Configure conversion to lower out SCF operations.
ConversionTarget target(getContext());
target.addIllegalOp<scf::IfOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
signalPassFailure();
}

std::unique_ptr<Pass> mlir::createConvertSCFToEmitCPass() {
return std::make_unique<SCFToEmitCPass>();
}
Loading