Skip to content

[mlir][emitc] Add op modelling C expressions #71631

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/EmitC/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
96 changes: 90 additions & 6 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ include "mlir/Dialect/EmitC/IR/EmitCTypes.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/RegionKindInterface.td"

//===----------------------------------------------------------------------===//
// EmitC op definitions
Expand Down Expand Up @@ -247,6 +248,83 @@ def EmitC_DivOp : EmitC_BinaryOp<"div", []> {
let results = (outs FloatIntegerIndexOrOpaqueType);
}

def EmitC_ExpressionOp : EmitC_Op<"expression",
[HasOnlyGraphRegion, SingleBlockImplicitTerminator<"emitc::YieldOp">,
NoRegionArguments]> {
let summary = "Expression operation";
let description = [{
The `expression` operation returns a single SSA value which is yielded by
its single-basic-block region. The operation doesn't take any arguments.

As the operation is to be emitted as a C expression, the operations within
its body must form a single Def-Use tree of emitc ops whose result is
yielded by a terminating `yield`.

Example:

```mlir
%r = emitc.expression : () -> i32 {
%0 = emitc.add %a, %b : (i32, i32) -> i32
%1 = emitc.call "foo"(%0) : () -> i32
%2 = emitc.add %c, %d : (i32, i32) -> i32
%3 = emitc.mul %1, %2 : (i32, i32) -> i32
yield %3
}
```

May be emitted as

```c++
int32_t v7 = foo(v1 + v2) * (v3 + v4);
```

The operations allowed within expression body are emitc.add, emitc.apply,
emitc.call, emitc.cast, emitc.cmp, emitc.div, emitc.mul, emitc.rem and
emitc.sub.

When specified, the optional `do_not_inline` indicates that the expression is
to be emitted as seen above, i.e. as the rhs of an EmitC SSA value
definition. Otherwise, the expression may be emitted inline, i.e. directly
at its use.
}];

let arguments = (ins UnitAttr:$do_not_inline);
let results = (outs AnyType:$result);
let regions = (region SizedRegion<1>:$region);

let hasVerifier = 1;
let assemblyFormat = "attr-dict (`noinline` $do_not_inline^)? `:` type($result) $region";

let extraClassDeclaration = [{
static bool isCExpression(Operation &op) {
return isa<emitc::AddOp, emitc::ApplyOp, emitc::CallOpaqueOp,
emitc::CastOp, emitc::CmpOp, emitc::DivOp, emitc::MulOp,
emitc::RemOp, emitc::SubOp>(op);
}
bool hasSideEffects() {
auto predicate = [](Operation &op) {
assert(isCExpression(op) && "Expected a C expression");
// Conservatively assume calls to read and write memory.
if (isa<emitc::CallOpaqueOp>(op))
return true;
// De-referencing reads modifiable memory, address-taking has no
// side-effect.
auto applyOp = dyn_cast<emitc::ApplyOp>(op);
if (applyOp)
return applyOp.getApplicableOperator() == "*";
// Any operation using variables is assumed to have a side effect of
// reading memory mutable by emitc::assign ops.
return llvm::any_of(op.getOperands(), [](Value operand) {
Operation *def = operand.getDefiningOp();
return def && isa<emitc::VariableOp>(def);
});
};
return llvm::any_of(getRegion().front().without_terminator(), predicate);
};
Operation *getRootOp();
}];
}

def EmitC_ForOp : EmitC_Op<"for",
[AllTypesMatch<["lowerBound", "upperBound", "step"]>,
SingleBlockImplicitTerminator<"emitc::YieldOp">,
Expand Down Expand Up @@ -494,18 +572,24 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> {
}

def EmitC_YieldOp : EmitC_Op<"yield",
[Pure, Terminator, ParentOneOf<["IfOp", "ForOp"]>]> {
[Pure, Terminator, ParentOneOf<["ExpressionOp", "IfOp", "ForOp"]>]> {
let summary = "block termination operation";
let description = [{
"yield" terminates blocks within EmitC control-flow operations. Since
control-flow constructs in C do not return values, this operation doesn't
take any arguments.
"yield" terminates its parent EmitC op's region, optionally yielding
an SSA value. The semantics of how the values are yielded is defined by the
parent operation.
If "yield" has an operand, the operand must match the parent operation's
result. If the parent operation defines no values, then the "emitc.yield"
may be left out in the custom syntax and the builders will insert one
implicitly. Otherwise, it has to be present in the syntax to indicate which
value is yielded.
}];

let arguments = (ins);
let arguments = (ins Optional<AnyType>:$result);
let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];

let assemblyFormat = [{ attr-dict }];
let hasVerifier = 1;
let assemblyFormat = [{ attr-dict ($result^ `:` type($result))? }];
}

def EmitC_IfOp : EmitC_Op<"if",
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name EmitC)
add_public_tablegen_target(MLIREmitCTransformsIncGen)

add_mlir_doc(Passes EmitCPasses ./ -gen-pass-doc)
35 changes: 35 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_EMITC_TRANSFORMS_PASSES_H_
#define MLIR_DIALECT_EMITC_TRANSFORMS_PASSES_H_

#include "mlir/Pass/Pass.h"

namespace mlir {
namespace emitc {

//===----------------------------------------------------------------------===//
// Passes
//===----------------------------------------------------------------------===//

/// Creates an instance of the C-style expressions forming pass.
std::unique_ptr<Pass> createFormExpressionsPass();

//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//

/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"

} // namespace emitc
} // namespace mlir

#endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES_H_
24 changes: 24 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
//===-- Passes.td - pass definition file -------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_EMITC_TRANSFORMS_PASSES
#define MLIR_DIALECT_EMITC_TRANSFORMS_PASSES

include "mlir/Pass/PassBase.td"

def FormExpressions : Pass<"form-expressions"> {
let summary = "Form C-style expressions from C-operator ops";
let description = [{
The pass wraps emitc ops modelling C operators in emitc.expression ops and
then folds single-use expressions into their users where possible.
}];
let constructor = "mlir::emitc::createFormExpressionsPass()";
let dependentDialects = ["emitc::EmitCDialect"];
}

#endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES
34 changes: 34 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//===- Transforms.h - EmitC transformations as patterns --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_EMITC_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_EMITC_TRANSFORMS_TRANSFORMS_H

#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir {
namespace emitc {

//===----------------------------------------------------------------------===//
// Expression transforms
//===----------------------------------------------------------------------===//

ExpressionOp createExpression(Operation *op, OpBuilder &builder);

//===----------------------------------------------------------------------===//
// Populate functions
//===----------------------------------------------------------------------===//

/// Populates `patterns` with expression-related patterns.
void populateExpressionPatterns(RewritePatternSet &patterns);

} // namespace emitc
} // namespace mlir

#endif // MLIR_DIALECT_EMITC_TRANSFORMS_TRANSFORMS_H
2 changes: 2 additions & 0 deletions mlir/include/mlir/InitAllPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "mlir/Dialect/Async/Passes.h"
#include "mlir/Dialect/Bufferization/Pipelines/Passes.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/EmitC/Transforms/Passes.h"
#include "mlir/Dialect/Func/Transforms/Passes.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
Expand Down Expand Up @@ -86,6 +87,7 @@ inline void registerAllPasses() {
vector::registerVectorPasses();
arm_sme::registerArmSMEPasses();
arm_sve::registerArmSVEPasses();
emitc::registerEmitCPasses();

// Dialect pipelines
bufferization::registerBufferizationPipelines();
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/EmitC/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
61 changes: 61 additions & 0 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,50 @@ LogicalResult emitc::ConstantOp::verify() {

OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }

//===----------------------------------------------------------------------===//
// ExpressionOp
//===----------------------------------------------------------------------===//

Operation *ExpressionOp::getRootOp() {
auto yieldOp = cast<YieldOp>(getBody()->getTerminator());
Value yieldedValue = yieldOp.getResult();
Operation *rootOp = yieldedValue.getDefiningOp();
assert(rootOp && "Yielded value not defined within expression");
return rootOp;
}

LogicalResult ExpressionOp::verify() {
Type resultType = getResult().getType();
Region &region = getRegion();

Block &body = region.front();

if (!body.mightHaveTerminator())
return emitOpError("must yield a value at termination");

auto yield = cast<YieldOp>(body.getTerminator());
Value yieldResult = yield.getResult();

if (!yieldResult)
return emitOpError("must yield a value at termination");

Type yieldType = yieldResult.getType();

if (resultType != yieldType)
return emitOpError("requires yielded type to match return type");

for (Operation &op : region.front().without_terminator()) {
if (!isCExpression(op))
return emitOpError("contains an unsupported operation");
if (op.getNumResults() != 1)
return emitOpError("requires exactly one result for each operation");
if (!op.getResult(0).hasOneUse())
return emitOpError("requires exactly one use for each operation");
}

return success();
}

//===----------------------------------------------------------------------===//
// ForOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -530,6 +574,23 @@ LogicalResult emitc::VariableOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//

LogicalResult emitc::YieldOp::verify() {
Value result = getResult();
Operation *containingOp = getOperation()->getParentOp();

if (result && containingOp->getNumResults() != 1)
return emitOpError() << "yields a value not returned by parent";

if (!result && containingOp->getNumResults() != 0)
return emitOpError() << "does not yield a value to be returned by parent";

return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
Expand Down
16 changes: 16 additions & 0 deletions mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
add_mlir_dialect_library(MLIREmitCTransforms
Transforms.cpp
FormExpressions.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/EmitC/Transforms

DEPENDS
MLIREmitCTransformsIncGen

LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIREmitCDialect
MLIRTransforms
)
60 changes: 60 additions & 0 deletions mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
//===- FormExpressions.cpp - Form C-style expressions --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass that forms EmitC operations modeling C operators
// into C-style expressions using the emitc.expression op.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/EmitC/Transforms/Passes.h"
#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace emitc {
#define GEN_PASS_DEF_FORMEXPRESSIONS
#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
} // namespace emitc
} // namespace mlir

using namespace mlir;
using namespace emitc;

namespace {
struct FormExpressionsPass
: public emitc::impl::FormExpressionsBase<FormExpressionsPass> {
void runOnOperation() override {
Operation *rootOp = getOperation();
MLIRContext *context = rootOp->getContext();

// Wrap each C operator op with an expression op.
OpBuilder builder(context);
auto matchFun = [&](Operation *op) {
if (emitc::ExpressionOp::isCExpression(*op))
createExpression(op, builder);
};
rootOp->walk(matchFun);

// Fold expressions where possible.
RewritePatternSet patterns(context);
populateExpressionPatterns(patterns);

if (failed(applyPatternsAndFoldGreedily(rootOp, std::move(patterns))))
return signalPassFailure();
}

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<emitc::EmitCDialect>();
}
};
} // namespace

std::unique_ptr<Pass> mlir::emitc::createFormExpressionsPass() {
return std::make_unique<FormExpressionsPass>();
}
Loading